From 604cfd4c7aae0028a10e5d45e17a3ec3e5f978b2 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Thu, 16 Oct 2025 22:33:41 +0000 Subject: [PATCH 001/130] bmm_fp8 --- flashinfer/gemm.py | 71 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 63a2f7e211..e2871f8e02 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -381,10 +381,6 @@ def fp8_gemm_sm100( if CUDNN_AVAILABLE and "cudnn" in runner_names: runners.append(_cudnn_gemm_fp8_runner()) - if len(runners) == 0: - major, minor = get_compute_capability(torch.device("cuda")) - raise ValueError(f"No valid runner found for current device sm{major}{minor}") - tuner = AutoTuner.get() a_tensor_index = 0 out_tensor_index = 4 @@ -2009,6 +2005,70 @@ def mm_fp4( return out +def _check_bmm_fp8_problem_size( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + _validate_fp8_output_dtype(dtype) + return True + + +@supported_compute_capability([89, 90, 100, 103, 120]) +def _cudnn_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + _check_cudnn_availability() + return True + + +@supported_compute_capability([89, 90, 100, 103, 120]) +def _cublas_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + return True + + +@supported_compute_capability([100, 103, 110, 120, 121]) +def _cutlass_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: + raise ValueError("e5m2 is not supported for bmm_fp8 with cutlass backend") + return True + + +@backend_requirement( + { + "cudnn": _cudnn_bmm_fp8_requirement, + "cublas": _cublas_bmm_fp8_requirement, + "cutlass": _cutlass_bmm_fp8_requirement, + "auto": _cublas_bmm_fp8_requirement, # cublas default + }, + common_check=_check_bmm_fp8_problem_size, +) def bmm_fp8( A: torch.Tensor, B: torch.Tensor, @@ -2073,7 +2133,6 @@ def bmm_fp8( >>> out.dtype torch.bfloat16 """ - _validate_fp8_output_dtype(dtype) if out is None: out = torch.empty( @@ -2091,8 +2150,6 @@ def bmm_fp8( elif backend == "cublas": backends = ["cublas"] elif backend == "cutlass": - if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: - raise ValueError("e5m2 is not supported for cutlass backend") backends = ["cutlass"] elif backend == "auto": backends = ["cutlass", "cublas", "cudnn"] From 499dcc5b5b640ec1c3c5f20f30a3a412273b5ca4 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Thu, 16 Oct 2025 22:33:41 +0000 Subject: [PATCH 002/130] bmm_fp8 --- flashinfer/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index e2871f8e02..9102324b54 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -349,7 +349,7 @@ def forward( cutlass_fp8_gemm_runner=cutlass_fp8_gemm_runner, ) - +# This is just 
helper for bmm_fp8 def fp8_gemm_sm100( a: torch.Tensor, b: torch.Tensor, From ad39f6725825736a6169e0311c362b0a07842161 Mon Sep 17 00:00:00 2001 From: jimmzhou Date: Wed, 29 Oct 2025 00:01:29 +0000 Subject: [PATCH 003/130] gemm.py refactor --- flashinfer/deep_gemm.py | 155 +++++++++++--- flashinfer/gemm.py | 440 ++++++++++++++++++++++++++++++---------- 2 files changed, 455 insertions(+), 140 deletions(-) diff --git a/flashinfer/deep_gemm.py b/flashinfer/deep_gemm.py index 4da91750fd..0178a4d174 100644 --- a/flashinfer/deep_gemm.py +++ b/flashinfer/deep_gemm.py @@ -45,7 +45,12 @@ from .cuda_utils import checkCudaErrors from .jit.cubin_loader import get_cubin from .jit.env import FLASHINFER_CUBIN_DIR -from .utils import ceil_div, round_up +from .utils import ( + ceil_div, + round_up, + supported_compute_capability, + backend_requirement, +) class GemmType(enum.Enum): @@ -1358,24 +1363,27 @@ def m_grouped_fp8_gemm_nt_masked_sm10x( runtime(**all_kwargs) -def m_grouped_fp8_gemm_nt_contiguous( +@supported_compute_capability([100, 103]) +def _check_group_deepgemm_fp8_nt_contiguous_problem_size( a_fp8: Tuple[torch.Tensor, torch.Tensor], b_fp8: Tuple[torch.Tensor, torch.Tensor], d: torch.Tensor, m_indices: torch.Tensor, recipe: Optional[Tuple[int, int, int]] = None, compiled_dims: str = "nk", -) -> None: - # Compiled dims can be upper cases - compiled_dims = compiled_dims.lower() - +) -> bool: # NOTES: shape must be `[M, K] @ [G, N, K].mT` major_a = get_major_type_ab(a_fp8[0]) major_b = get_major_type_ab(b_fp8[0]) - assert major_a == MajorTypeAB.KMajor - if must_be_k_major(): - assert major_b == MajorTypeAB.KMajor - assert m_indices.is_contiguous() + if major_a != MajorTypeAB.KMajor: + raise ValueError(f"major_a must be KMajor, but got {major_a}") + if must_be_k_major() and (major_b != MajorTypeAB.KMajor): + raise ValueError(f"major_b must be KMajor, but got {major_b}") + + if not m_indices.is_contiguous(): + raise ValueError( + f"m_indices must be contiguous, but got {m_indices.is_contiguous()}" + ) a, sfa = a_fp8 b, sfb = b_fp8 @@ -1385,15 +1393,48 @@ def m_grouped_fp8_gemm_nt_contiguous( m__ = m_indices.numel() # Type and shape checks - assert m == m_ == m__ and n == n_ and k == k_ - assert n > 0 and k > 0 and num_groups > 0 - assert a.dtype == torch.float8_e4m3fn - assert b.dtype == torch.float8_e4m3fn - assert d.dtype == torch.bfloat16 - assert m_indices.dtype == torch.int32 + if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__: + raise ValueError( + f"Shape mismatch. 
m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}" + ) + if a.dtype != torch.float8_e4m3fn: + raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}") + if b.dtype != torch.float8_e4m3fn: + raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}") + if d.dtype != torch.bfloat16: + raise ValueError(f"d must be bfloat16, but got {d.dtype}") + if m_indices.dtype != torch.int32: + raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}") # D must be N-major - assert get_major_type_cd(d) == MajorTypeCD.NMajor + if get_major_type_cd(d) != MajorTypeCD.NMajor: + raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}") + + return True + + +@backend_requirement( + {}, + common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size, +) +def m_grouped_fp8_gemm_nt_contiguous( + a_fp8: Tuple[torch.Tensor, torch.Tensor], + b_fp8: Tuple[torch.Tensor, torch.Tensor], + d: torch.Tensor, + m_indices: torch.Tensor, + recipe: Optional[Tuple[int, int, int]] = None, + compiled_dims: str = "nk", +) -> None: + # Compiled dims can be upper cases + compiled_dims = compiled_dims.lower() + + major_a = get_major_type_ab(a_fp8[0]) + major_b = get_major_type_ab(b_fp8[0]) + + a, sfa = a_fp8 + b, sfb = b_fp8 + m, k = a.shape + num_groups, n, k_ = b.shape # Do nothing if the problem is empty if m == 0: @@ -1423,6 +1464,72 @@ def m_grouped_fp8_gemm_nt_contiguous( impl(a, sfa, b, sfb, d, m_indices) +@supported_compute_capability([100, 103]) +def _check_m_grouped_fp8_gemm_nt_masked_problem_size( + a_fp8: Tuple[torch.Tensor, torch.Tensor], + b_fp8: Tuple[torch.Tensor, torch.Tensor], + d: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, + recipe: Optional[Tuple[int, int, int]] = None, + compiled_dims: str = "nk", +) -> bool: + major_a = get_major_type_ab(a_fp8[0]) + major_b = get_major_type_ab(b_fp8[0]) + if major_a != MajorTypeAB.KMajor: + raise ValueError(f"major_a must be KMajor, but got {major_a}") + if major_b != MajorTypeAB.KMajor: + raise ValueError(f"major_b must be KMajor, but got {major_b}") + + if not masked_m.is_contiguous(): + raise ValueError( + f"masked_m must be contiguous, but got {masked_m.is_contiguous()}" + ) + + a, sfa = a_fp8 + b, sfb = b_fp8 + num_groups, m, k = a.shape + num_groups_, n, k_ = b.shape + num_groups__, m_, n_ = d.shape + num_groups___ = masked_m.numel() + + # Type and shape checks + if ( + num_groups != num_groups_ + or num_groups != num_groups__ + or num_groups != num_groups___ + ): + raise ValueError( + f"num_groups mismatch. num_groups = {num_groups}, num_groups_ = {num_groups_}, num_groups__ = {num_groups__}, num_groups___ = {num_groups___}" + ) + if m != m_ or n != n_ or k != k_: + raise ValueError( + f"m, n, k mismatch. 
m = {m}, m_ = {m_}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}" + ) + if expected_m <= 0 or m <= 0 or n <= 0 or k <= 0 or num_groups <= 0: + raise ValueError( + f"expected_m, m, n, k, num_groups must be greater than 0, but got expected_m = {expected_m}, m = {m}, n = {n}, k = {k}, num_groups = {num_groups}" + ) + if a.dtype != torch.float8_e4m3fn: + raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}") + if b.dtype != torch.float8_e4m3fn: + raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}") + if d.dtype != torch.bfloat16: + raise ValueError(f"d must be bfloat16, but got {d.dtype}") + if masked_m.dtype != torch.int32: + raise ValueError(f"masked_m must be int32, but got {masked_m.dtype}") + + # D must be N-major + if get_major_type_cd(d) != MajorTypeCD.NMajor: + raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}") + + return True + + +@backend_requirement( + {}, + common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size, +) def m_grouped_fp8_gemm_nt_masked( a_fp8: Tuple[torch.Tensor, torch.Tensor], b_fp8: Tuple[torch.Tensor, torch.Tensor], @@ -1445,20 +1552,6 @@ def m_grouped_fp8_gemm_nt_masked( b, sfb = b_fp8 num_groups, m, k = a.shape num_groups_, n, k_ = b.shape - num_groups__, m_, n_ = d.shape - num_groups___ = masked_m.numel() - - # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ - assert m == m_ and n == n_ and k == k_ - assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert a.dtype == torch.float8_e4m3fn - assert b.dtype == torch.float8_e4m3fn - assert d.dtype == torch.bfloat16 - assert masked_m.dtype == torch.int32 - - # D must be N-major - assert get_major_type_cd(d) == MajorTypeCD.NMajor # Transform SFA and SFB into compute-required layout recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 9102324b54..c9f61b6d92 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -349,6 +349,7 @@ def forward( cutlass_fp8_gemm_runner=cutlass_fp8_gemm_runner, ) + # This is just helper for bmm_fp8 def fp8_gemm_sm100( a: torch.Tensor, @@ -2005,7 +2006,8 @@ def mm_fp4( return out -def _check_bmm_fp8_problem_size( +@supported_compute_capability([89, 90, 100, 103, 120]) +def _cudnn_bmm_fp8_requirement( A: torch.Tensor, B: torch.Tensor, A_scale: torch.Tensor, @@ -2014,12 +2016,12 @@ def _check_bmm_fp8_problem_size( out: Optional[torch.Tensor] = None, backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", ): - _validate_fp8_output_dtype(dtype) + _check_cudnn_availability() return True @supported_compute_capability([89, 90, 100, 103, 120]) -def _cudnn_bmm_fp8_requirement( +def _cublas_bmm_fp8_requirement( A: torch.Tensor, B: torch.Tensor, A_scale: torch.Tensor, @@ -2028,12 +2030,11 @@ def _cudnn_bmm_fp8_requirement( out: Optional[torch.Tensor] = None, backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", ): - _check_cudnn_availability() return True -@supported_compute_capability([89, 90, 100, 103, 120]) -def _cublas_bmm_fp8_requirement( +@supported_compute_capability([100, 103, 110, 120, 121]) +def _cutlass_bmm_fp8_requirement( A: torch.Tensor, B: torch.Tensor, A_scale: torch.Tensor, @@ -2042,11 +2043,12 @@ def _cublas_bmm_fp8_requirement( out: Optional[torch.Tensor] = None, backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", ): + if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: + raise ValueError("e5m2 is not supported for bmm_fp8 with cutlass 
backend") return True -@supported_compute_capability([100, 103, 110, 120, 121]) -def _cutlass_bmm_fp8_requirement( +def _check_bmm_fp8_problem_size( A: torch.Tensor, B: torch.Tensor, A_scale: torch.Tensor, @@ -2055,8 +2057,7 @@ def _cutlass_bmm_fp8_requirement( out: Optional[torch.Tensor] = None, backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", ): - if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: - raise ValueError("e5m2 is not supported for bmm_fp8 with cutlass backend") + _validate_fp8_output_dtype(dtype) return True @@ -2160,6 +2161,78 @@ def bmm_fp8( return out +@supported_compute_capability([100, 103, 120, 121]) +def _cutlass_gemm_fp8_nt_groupwise_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + scale_major_mode: Optional[Literal["MN", "K"]] = None, + mma_sm: int = 1, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + backend: Literal["cutlass", "trtllm"] = "cutlass", +): + if scale_major_mode is None: + raise ValueError("scale_major_mode is required in CUTLASS") + + return True + + +@supported_compute_capability([100, 103]) +def _trtllm_gemm_fp8_nt_groupwise_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + scale_major_mode: Optional[Literal["MN", "K"]] = None, + mma_sm: int = 1, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + backend: Literal["cutlass", "trtllm"] = "cutlass", +): + if scale_granularity_mnk != (1, 128, 128): + raise ValueError("scale_granularity_mnk must be (1, 128, 128) in TRTLLM") + if a.shape[1] < 256: + raise ValueError("a.shape[1] must be >= 256 in TRTLLM") + + return True + + +def _check_gemm_fp8_nt_groupwise_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + scale_major_mode: Optional[Literal["MN", "K"]] = None, + mma_sm: int = 1, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + backend: Literal["cutlass", "trtllm"] = "cutlass", +): + if a.ndim != 2 or b.ndim != 2: + raise ValueError(f"Shape mismatch. a.shape = {a.shape}, b.shape = {b.shape}") + + if a.shape[1] != b.shape[1]: + raise ValueError( + f"Shape mismatch. a.shape[1] = {a.shape[1]}, b.shape[1] = {b.shape[1]}" + ) + + _validate_fp8_output_dtype(out_dtype) + + return True + + +@backend_requirement( + { + "cutlass": _cutlass_gemm_fp8_nt_groupwise_requirement, + "trtllm": _trtllm_gemm_fp8_nt_groupwise_requirement, + }, + common_check=_check_gemm_fp8_nt_groupwise_problem_size, +) def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -2233,27 +2306,16 @@ def gemm_fp8_nt_groupwise( ----- The ``m`` should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement. """ - if backend == "trtllm" and _match_sm_version(a.device, ["110"]): - raise ValueError("TRTLLM FP8 GEMM is not supported on SM110.") workspace_buffer = _get_cache_buf( "gemm_fp8_nt_groupwise_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) - if a.ndim != 2 or b.ndim != 2: - raise ValueError(f"Shape mismatch. a.shape = {a.shape}, b.shape = {b.shape}") - - if a.shape[1] != b.shape[1]: - raise ValueError( - f"Shape mismatch. 
a.shape[1] = {a.shape[1]}, b.shape[1] = {b.shape[1]}" - ) if out is None: out_dtype = out_dtype or torch.bfloat16 else: out_dtype = out.dtype - _validate_fp8_output_dtype(out_dtype) - # NOTE(Zihao): (out_specified, need_padding) # (False, False) -> create out_padded tensor explicitly # (False, True) -> create out_padded tensor explicitly @@ -2269,18 +2331,6 @@ def gemm_fp8_nt_groupwise( ) if backend == "cutlass": - if not _match_sm_version(a.device, ["100", "103", "110", "120", "121"]): - raise ValueError( - "gemm_fp8_nt_groupwise is only supported on SM100, SM103, SM110, SM120, or SM121 in cutlass backend." - ) - elif backend == "trtllm": - if not _match_sm_version(a.device, ["100", "103"]): - raise ValueError( - "gemm_fp8_nt_groupwise is only supported on SM100, SM103 in trtllm backend." - ) - - if backend == "cutlass": - assert scale_major_mode is not None if is_sm120a_supported(a.device) or is_sm121a_supported(a.device): # SM120/121 doesn't use mma_sm parameter get_gemm_sm120_module().gemm_fp8_nt_groupwise( @@ -2308,8 +2358,6 @@ def gemm_fp8_nt_groupwise( else: raise ValueError(f"Unsupported device for FP8 GEMM: {a.device}") elif backend == "trtllm": - assert scale_granularity_mnk == (1, 128, 128) - assert a.shape[1] >= 256 # mma_sm is ignored get_trtllm_gemm_module().trtllm_gemm( workspace_buffer, @@ -2468,6 +2516,48 @@ def pad_up(x, y): ) +@supported_compute_capability([100, 103, 120, 121]) +def _check_gemm_fp8_nt_blockscaled_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + scale_major_mode: Optional[Literal["MN", "K"]] = "MN", + mma_sm: int = 1, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +): + _check_gemm_fp8_nt_groupwise_problem_size( + a, + b, + a_scale, + b_scale, + scale_major_mode, + mma_sm, + out, + out_dtype, + backend="cutlass", + ) + + _cutlass_gemm_fp8_nt_groupwise_requirement( + a, + b, + a_scale, + b_scale, + scale_major_mode, + mma_sm, + out, + out_dtype, + backend="cutlass", + ) + + return True + + +@backend_requirement( + {}, + common_check=_check_gemm_fp8_nt_blockscaled_problem_size, +) def gemm_fp8_nt_blockscaled( a: torch.Tensor, b: torch.Tensor, @@ -2496,6 +2586,79 @@ def gemm_fp8_nt_blockscaled( ) +@supported_compute_capability([100, 120, 121]) +def _check_group_gemm_fp8_nt_groupwise_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + m_indptr: torch.Tensor, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + scale_major_mode: Literal["MN", "K"] = "MN", + mma_sm: int = 1, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +): + if a.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError(f"a must be a float8 tensor, but got {a.dtype}") + if b.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError(f"b must be a float8 tensor, but got {b.dtype}") + if a_scale.dtype not in [torch.float32]: + raise ValueError(f"a_scale must be a float32 tensor, but got {a_scale.dtype}") + if b_scale.dtype not in [torch.float32]: + raise ValueError(f"b_scale must be a float32 tensor, but got {b_scale.dtype}") + if m_indptr.dtype not in [torch.int32]: + raise ValueError(f"m_indptr must be a int32 tensor, but got {m_indptr.dtype}") + if scale_major_mode not in ["MN", "K"]: + raise ValueError( + f"scale_major_mode must be either 'MN' or 'K', but got {scale_major_mode}" + ) + if mma_sm not in [1, 2]: + raise ValueError(f"mma_sm must be either 1 or 2, but got 
{mma_sm}") + + # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance + n = b.shape[1] + k = b.shape[2] + + if out is None: + if out_dtype is None: + out_dtype = torch.bfloat16 + else: + if out_dtype is None: + out_dtype = out.dtype + if out.shape != (a.shape[0], n): + raise ValueError( + f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}" + ) + if out.dtype != out_dtype: + raise ValueError( + f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}" + ) + + _validate_fp8_output_dtype(out_dtype) + + if a.shape[1] != k: + raise ValueError(f"Shape mismatch. a.shape[1] = {a.shape[1]}, k = {k}") + if n % 8 != 0: + raise ValueError(f"n must be a multiple of 8, but got {n}") + if k % 16 != 0: + raise ValueError(f"k must be a multiple of 16, but got {k}") + + num_groups = m_indptr.shape[0] - 1 + + if is_sm120a_supported(a.device) or is_sm121a_supported(a.device): + if num_groups > 1: + raise RuntimeError( + "group_gemm_fp8_nt_groupwise has correctness issues for num_groups > 1 on SM120/121" + ) + + return True + + +@backend_requirement( + {}, + common_check=_check_group_gemm_fp8_nt_groupwise_problem_size, +) def group_gemm_fp8_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2560,19 +2723,6 @@ def group_gemm_fp8_nt_groupwise( Each value in ``m_indptr`` should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement. """ - if ( - not is_sm100a_supported(a.device) - and not is_sm120a_supported(a.device) - and not is_sm121a_supported(a.device) - ): - raise ValueError( - "gemm_fp8_nt_groupwise is only supported on SM100, SM120, and SM121." - ) - if not (_match_sm_version(a.device, ["100", "103", "110", "120", "121"])): - raise ValueError( - "gemm_fp8_nt_groupwise is only supported on SM100, SM103, SM110, SM120, or SM121." 
- ) - int_workspace_buffer = _get_cache_buf( "group_gemm_fp8_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) @@ -2580,46 +2730,21 @@ def group_gemm_fp8_nt_groupwise( "group_gemm_fp8_nt_groupwise_float_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) - assert a.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert b.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert a_scale.dtype == torch.float32 - assert b_scale.dtype == torch.float32 - assert m_indptr.dtype == torch.int32 - assert scale_major_mode in ["MN", "K"] - assert mma_sm in [1, 2] if out is None: if out_dtype is None: out_dtype = torch.bfloat16 else: if out_dtype is None: out_dtype = out.dtype - _validate_fp8_output_dtype(out_dtype) - num_groups = m_indptr.shape[0] - 1 - assert b.shape[0] == num_groups n = b.shape[1] k = b.shape[2] - # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance - assert a.shape[1] == k - align_n = 8 - align_k = 16 - assert n % align_n == 0 - assert k % align_k == 0 - out_shape = (a.shape[0], n) if out is None: out = torch.empty(out_shape, dtype=out_dtype, device=a.device) - else: - assert out.shape == out_shape - assert out.dtype == out_dtype if is_sm120a_supported(a.device) or is_sm121a_supported(a.device): - # it has correctness issues for num_groups > 1 - if num_groups > 1: - raise RuntimeError( - "group_gemm_fp8_nt_groupwise has correctness issues for num_groups > 1 on SM120/121" - ) # SM120/121 doesn't use mma_sm parameter get_gemm_sm120_module().group_gemm_fp8_nt_groupwise( int_workspace_buffer, @@ -2651,13 +2776,96 @@ def group_gemm_fp8_nt_groupwise( scale_major_mode, mma_sm, ) + return out + + +@supported_compute_capability([100, 103, 110, 120, 121]) +def _check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + m_indptr: torch.Tensor, + mma_sm: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + swap_ab: bool = True, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +): + if a.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError( + f"a must be a float8_e4m3fn or float8_e5m2 tensor, but got {a.dtype}" + ) + if b.dtype != torch.uint8: + raise ValueError(f"b must be a uint8 tensor, but got {b.dtype}") + if a_scale.dtype != torch.uint8: + raise ValueError(f"a_scale must be a uint8 tensor, but got {a_scale.dtype}") + if b_scale.dtype != torch.uint8: + raise ValueError(f"b_scale must be a uint8 tensor, but got {b_scale.dtype}") + if m_indptr.dtype != torch.int32: + raise ValueError(f"m_indptr must be a int32 tensor, but got {m_indptr.dtype}") + if mma_sm not in [1, 2]: + raise ValueError(f"mma_sm must be either 1 or 2, but got {mma_sm}") + if tile_m not in [128]: + raise ValueError(f"tile_m must be 128, but got {tile_m}") + if tile_n not in [64, 128, 192, 256]: + raise ValueError(f"tile_n must be one of [64, 128, 192, 256], but got {tile_n}") + if tile_k not in [128, 256]: + raise ValueError(f"tile_k must be either 128 or 256, but got {tile_k}") + if swap_ab not in [True, False]: + raise ValueError(f"swap_ab must be a boolean value, but got {swap_ab}") + + # Determine out_dtype if not specified + if out is None: + if out_dtype is None: + out_dtype = torch.bfloat16 else: + if out_dtype is None: + out_dtype = out.dtype + + if out_dtype not in [torch.bfloat16, torch.float16]: raise ValueError( - f"group_gemm_fp8_nt_groupwise requires SM100, SM120, or SM121, but got {a.device}" + 
f"out_dtype must be either torch.bfloat16 or torch.float16, but got {out_dtype}" ) - return out + num_groups = m_indptr.shape[0] - 1 + if b.shape[0] != num_groups: + raise ValueError( + f"b.shape[0] must equal num_groups (m_indptr.shape[0] - 1), but got b.shape[0]={b.shape[0]}, num_groups={num_groups}" + ) + + n = b.shape[1] + k = b.shape[2] * 2 # Multiply by 2 because b is e2m1 packed as uint8 + + # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance + if a.shape[1] != k: + raise ValueError( + f"a.shape[1] must equal k, but got a.shape[1]={a.shape[1]}, k={k}" + ) + align_n = 8 + align_k = 128 + if n % align_n != 0: + raise ValueError(f"n must be a multiple of {align_n}, but got n={n}") + if k % align_k != 0: + raise ValueError(f"k must be a multiple of {align_k}, but got k={k}") + + out_shape = (a.shape[0], n) + if out is not None: + if out.shape != out_shape: + raise ValueError(f"out.shape must be {out_shape}, but got {out.shape}") + if out.dtype != out_dtype: + raise ValueError(f"out.dtype must be {out_dtype}, but got {out.dtype}") + + return True + + +@backend_requirement( + {}, + common_check=_check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size, +) def group_gemm_mxfp8_mxfp4_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k // 2) @@ -2734,43 +2942,20 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( DEFAULT_WORKSPACE_SIZE, a.device, ) - - assert a.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert b.dtype == torch.uint8 - assert a_scale.dtype == torch.uint8 - assert b_scale.dtype == torch.uint8 - assert m_indptr.dtype == torch.int32 - assert mma_sm in [1, 2] - assert tile_m in [128] - assert tile_n in [64, 128, 192, 256] - assert tile_k in [128, 256] - assert swap_ab in [True, False] + # Determine out_dtype if not specified if out is None: if out_dtype is None: out_dtype = torch.bfloat16 else: if out_dtype is None: out_dtype = out.dtype - assert out_dtype in [torch.bfloat16, torch.float16] - num_groups = m_indptr.shape[0] - 1 - assert b.shape[0] == num_groups n = b.shape[1] k = b.shape[2] * 2 # Multiply by 2 because b is e2m1 packed as uint8 - # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance - assert a.shape[1] == k - align_n = 8 - align_k = 128 - assert n % align_n == 0 - assert k % align_k == 0 - out_shape = (a.shape[0], n) if out is None: out = torch.empty(out_shape, dtype=out_dtype, device=a.device) - else: - assert out.shape == out_shape - assert out.dtype == out_dtype get_gemm_sm100_module().group_gemm_mxfp4_nt_groupwise( int_workspace_buffer, @@ -2825,6 +3010,30 @@ def get_deepgemm_sm100_module(): return module +@supported_compute_capability([100, 103]) +def _check_group_deepgemm_fp8_nt_groupwise_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + m_indices: torch.Tensor, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> bool: + from flashinfer.deep_gemm import ( + _check_group_deepgemm_fp8_nt_contiguous_problem_size, + ) + + return _check_group_deepgemm_fp8_nt_contiguous_problem_size( + (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk + ) + + +@backend_requirement( + {}, + common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size, +) def group_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2939,11 +3148,6 @@ def group_deepgemm_fp8_nt_groupwise( """ 
from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_contiguous - if not _match_sm_version(a.device, ["100", "103"]): - raise ValueError( - "m_grouped_fp8_gemm_nt_contiguous is only supported on SM100, SM100, SM103." - ) - if out is None: out_dtype = out_dtype or torch.bfloat16 out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device) @@ -2955,6 +3159,29 @@ def group_deepgemm_fp8_nt_groupwise( return out +@supported_compute_capability([100, 103]) +def _check_batch_deepgemm_fp8_nt_groupwise( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, + scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> bool: + from flashinfer.deep_gemm import _check_m_grouped_fp8_gemm_nt_masked_problem_size + + return _check_m_grouped_fp8_gemm_nt_masked_problem_size( + (a, a_scale), (b, b_scale), out, masked_m, expected_m, scale_granularity_mnk + ) + + +@backend_requirement( + {}, + common_check=_check_batch_deepgemm_fp8_nt_groupwise, +) def batch_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (batch_size, m, k) b: torch.Tensor, # (batch_size, n, k) @@ -3072,11 +3299,6 @@ def batch_deepgemm_fp8_nt_groupwise( """ from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_masked - if not _match_sm_version(a.device, ["100", "103"]): - raise ValueError( - "m_grouped_fp8_gemm_nt_masked is only supported on SM100, SM103." - ) - if out is None: out_dtype = out_dtype or torch.bfloat16 out = torch.empty( From ebb610c8f37d0db0daa780fdcc8f400229427550 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Tue, 28 Oct 2025 23:10:26 -0700 Subject: [PATCH 004/130] unittest: Add head dim 256 test cases and mark as xfail (#1999) --- tests/attention/test_trtllm_gen_attention.py | 259 +++++++++++++++---- 1 file changed, 211 insertions(+), 48 deletions(-) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 7ce086a6ac..f14c57b1f1 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -577,47 +577,7 @@ def test_trtllm_batch_prefill_bs1( ) -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND -@pytest.mark.parametrize( - "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", - [ - (4, 1, 16, 2, 1), - (4, 1, 32, 2, 5), - (4, 2, 64, 2, 5), - (4, 3, 32, 2, 5), - (4, 3, 64, 2, 1), - (4, 4, 64, 4, 1), - (4, 5, 64, 4, 8), - (128, 1, 64, 2, 5), - (128, 2, 32, 4, 1), - (128, 3, 16, 4, 8), - (128, 4, 16, 2, 5), - (128, 5, 16, 2, 5), - (256, 1, 64, 4, 8), - (256, 2, 16, 2, 8), - (256, 3, 64, 4, 5), - (256, 4, 32, 2, 8), - (256, 5, 32, 2, 1), - ], -) -@pytest.mark.parametrize("window_left", [-1, 127]) -@pytest.mark.parametrize( - "q_dtype,kv_dtype,o_dtype", - [ - ("bf16", "bf16", "bf16"), - ("fp16", "fp16", "fp16"), - ("bf16", "fp8", "bf16"), - ("fp16", "fp8", "fp16"), - ("fp8", "fp8", "bf16"), - ("fp8", "fp8", "fp16"), - ("fp8", "fp8", "fp8"), - ("fp8", "fp8", "nvfp4"), - ], -) -@pytest.mark.parametrize("enable_pdl", [True, False, None]) -@pytest.mark.parametrize("enable_sink", [True, False]) -@pytest.mark.parametrize("max_in_kv_len", [110]) -def test_trtllm_batch_decode( +def _test_trtllm_batch_decode( kv_layout, batch_size, q_len_per_req, @@ -631,7 +591,13 @@ def test_trtllm_batch_decode( enable_pdl, enable_sink, max_in_kv_len, + head_dim, ): + """ + Common function for testing trtllm-gen decode. 
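+    The leading underscore keeps pytest from collecting it directly; the
+    parametrized test_trtllm_batch_decode* wrappers below call it.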
+ + Combinations of parameters are tested in test_trtllm_batch_decode() and test_trtllm_batch_decode_...() + """ compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] != 10: pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") @@ -642,7 +608,6 @@ def test_trtllm_batch_decode( # Set up test parameters torch.manual_seed(0) - head_dim = 128 # Generate random sequence lengths num_qo_heads = num_kv_heads * head_grp_size @@ -858,6 +823,82 @@ def test_trtllm_batch_decode( assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() +@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize( + "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + [ + (4, 1, 16, 2, 1), + (4, 1, 32, 2, 5), + (4, 2, 64, 2, 5), + (4, 3, 32, 2, 5), + (4, 3, 64, 2, 1), + (4, 4, 64, 4, 1), + (4, 5, 64, 4, 8), + (128, 1, 64, 2, 5), + (128, 2, 32, 4, 1), + (128, 3, 16, 4, 8), + (128, 4, 16, 2, 5), + (128, 5, 16, 2, 5), + (256, 1, 64, 4, 8), + (256, 2, 16, 2, 8), + (256, 3, 64, 4, 5), + (256, 4, 32, 2, 8), + (256, 5, 32, 2, 1), + ], +) +@pytest.mark.parametrize("window_left", [-1, 127]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("bf16", "bf16", "bf16"), + ("fp16", "fp16", "fp16"), + ("bf16", "fp8", "bf16"), + ("fp16", "fp8", "fp16"), + ("fp8", "fp8", "bf16"), + ("fp8", "fp8", "fp16"), + ("fp8", "fp8", "fp8"), + ("fp8", "fp8", "nvfp4"), + ], +) +@pytest.mark.parametrize("enable_pdl", [True, False, None]) +@pytest.mark.parametrize("enable_sink", [True, False]) +@pytest.mark.parametrize("max_in_kv_len", [110]) +@pytest.mark.parametrize("head_dim", [128]) +def test_trtllm_batch_decode( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + head_dim, +): + # General set of tests for trtllm-gen decode + _test_trtllm_batch_decode( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + head_dim, + ) + + @pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND @pytest.mark.parametrize( "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", @@ -875,6 +916,7 @@ def test_trtllm_batch_decode( @pytest.mark.parametrize("enable_pdl", [None]) @pytest.mark.parametrize("enable_sink", [False]) @pytest.mark.parametrize("max_in_kv_len", [8192]) +@pytest.mark.parametrize("head_dim", [128]) def test_trtllm_batch_decode_bs1( kv_layout, batch_size, @@ -889,9 +931,134 @@ def test_trtllm_batch_decode_bs1( enable_pdl, enable_sink, max_in_kv_len, + head_dim, ): + # Small number of test cases for batch size 1 pytest.xfail("trtllm-gen decode gets incorrect output with bs1") - test_trtllm_batch_decode( + _test_trtllm_batch_decode( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + head_dim, + ) + + +@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize( + "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + [ + (4, 1, 16, 2, 1), + (4, 1, 32, 2, 5), + (4, 3, 64, 2, 1), + (4, 4, 64, 4, 1), + (128, 3, 16, 4, 8), + (128, 4, 16, 2, 5), + (256, 4, 32, 2, 8), + (256, 5, 32, 2, 1), + ], +) 
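+# Reduced parameter sweep for head_dim = 256; the test body below marks these cases as xfail.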
+@pytest.mark.parametrize("window_left", [-1]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("bf16", "bf16", "bf16"), + ("fp16", "fp16", "fp16"), + ("fp8", "fp8", "fp16"), + ("fp8", "fp8", "fp8"), + ("fp8", "fp8", "nvfp4"), + ], +) +@pytest.mark.parametrize("enable_pdl", [None]) +@pytest.mark.parametrize("enable_sink", [False]) +@pytest.mark.parametrize("max_in_kv_len", [110]) +@pytest.mark.parametrize("head_dim", [256]) +def test_trtllm_batch_decode_head_dim_256( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + head_dim, +): + # Small number of test cases for head_dim = 256 + pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256") + _test_trtllm_batch_decode( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + head_dim, + ) + + +@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize( + "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + [ + (1, 1, 16, 2, 1), + (1, 1, 32, 2, 5), + (1, 3, 64, 2, 1), + (1, 4, 64, 4, 1), + ], +) +@pytest.mark.parametrize("window_left", [-1]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("bf16", "bf16", "bf16"), + ("fp8", "fp8", "fp8"), + ], +) +@pytest.mark.parametrize("enable_pdl", [None]) +@pytest.mark.parametrize("enable_sink", [False]) +@pytest.mark.parametrize("max_in_kv_len", [4096, 8192, 16384, 32768, 65536, 131072]) +@pytest.mark.parametrize("head_dim", [128]) +def test_trtllm_batch_decode_long_sequence_length( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + head_dim, +): + # Small number of test cases for long sequence length + pytest.xfail("trtllm-gen decode gets incorrect output with Long sequence length") + _test_trtllm_batch_decode( kv_layout, batch_size, q_len_per_req, @@ -905,6 +1072,7 @@ def test_trtllm_batch_decode_bs1( enable_pdl, enable_sink, max_in_kv_len, + head_dim, ) @@ -1053,8 +1221,3 @@ def test_trtllm_gen_prefill_deepseek_bs1( test_trtllm_gen_prefill_deepseek( batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal ) - - -if __name__ == "__main__": - test_trtllm_batch_prefill("HND", 128, 32, 2, 5, -1, "fp16", "fp16", "fp16", False) - test_trtllm_batch_decode("HND", 256, 3, 64, 4, 5, -1, "fp8", "fp8", "fp8", True) From bb6b62089a13db0a6b90887d68b161bbc1b5fc8e Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:47:25 -0700 Subject: [PATCH 005/130] feat: autotune tile_tokens_dim in trtllm-gen MOE (#1980) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description - Update the autotune logic in trtllm-gen moe. Instead of using a fixed `tile_tokens_dim`, tune in a range of `[max(8,tile_token_dim/2), tile_token_dim, min(128, tile_token_dim*2), min(128, tile_token_dim*4)]` - Add FP8 MOE autotune logic, initial PR https://github.com/flashinfer-ai/flashinfer/pull/1494 from @aleozlx, update logic to sync with new autotuner. - Update logic in `test_trtllm_gen_fused_moe.py`. 
- Update the `conftest.py` to speed up test, previously use `try_first` which introduce duplicate run - Add log_once in logger ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Runtime autotuning with per-tile dynamic routing and selectable MoE runner options (gated activation, shuffled-weight, weight-layout). * One-time (deduplicated) logging helpers added to JIT logger. * **Deprecations** * tile_tokens_dim removed from new paths and marked deprecated in legacy entry points; new tuning parameters introduced for autotuning. * **Tests** * Tests refactored for autotuning/routing with new helpers and improved handling/reporting for missing JIT cache. --------- Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: yzh119 --- .../bench_trtllm_gen_fused_moe_autotuner.py | 7 +- csrc/trtllm_fused_moe_kernel_launcher.cu | 320 +++++--- flashinfer/fused_moe/core.py | 364 ++++++--- flashinfer/jit/core.py | 30 +- tests/conftest.py | 4 +- tests/moe/test_trtllm_gen_fused_moe.py | 736 +++++++++++------- 6 files changed, 989 insertions(+), 472 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 952b479a1d..2a991829dd 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -11,7 +11,7 @@ from flashinfer.fused_moe import trtllm_fp4_block_scale_moe from flashinfer.autotuner import autotune from flashinfer.testing.utils import bench_gpu_time -from flashinfer.utils import device_support_pdl, calculate_tile_tokens_dim +from flashinfer.utils import device_support_pdl def bench_trtllm_gen_fused_moe_autotuner( @@ -99,9 +99,6 @@ def bench_trtllm_gen_fused_moe_autotuner( bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10 bias2 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10 - tile_tokens_dim = calculate_tile_tokens_dim( - num_tokens, num_experts, top_k, 64 if quant_mode == "MxFP4xBf16" else 128 - ) output1_scale_scalar = torch.tensor( [hidden_states_global_scale * w13_global_scale] * num_experts, device=device ) @@ -136,7 +133,7 @@ def bench_trtllm_gen_fused_moe_autotuner( 0, # local_expert_offset num_experts, None, # routed_scaling_factor - tile_tokens_dim, + None, # tile_tokens_dim RoutingMethodType.Renormalize.value, True, enable_pdl, diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 22c1e8e51e..538dc92725 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include #include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h" @@ -37,6 +39,41 @@ using 
tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; +// Utility function to compute the next power of two +inline int32_t nextPowerOfTwo(float value) { + int32_t n = static_cast(std::ceil(value)); + if (n <= 1) return 1; + + // If n is already a power of 2, return it + if ((n & (n - 1)) == 0) return n; + + // Find the next power of 2 + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n++; + + return n; +} + +std::set computeSelectedTileN(std::vector const& supported_tile_nums, + int64_t const num_tokens, int64_t const top_k, + int64_t const num_local_experts) { + float const avg_tokens_per_expert = static_cast(num_tokens * top_k) / num_local_experts; + int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), + supported_tile_nums.front(), supported_tile_nums.back()); + + std::set selected_tile_nums = { + std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim, + std::min(supported_tile_nums.back(), tile_tokens_dim * 2), + std::min(supported_tile_nums.back(), tile_tokens_dim * 4)}; + + return selected_tile_nums; +} + void trtllm_fp8_per_tensor_scale_moe_launcher( TensorView routing_logits, Optional routing_bias, TensorView hidden_states, TensorView gemm1_weights, TensorView output1_scales_scalar, @@ -46,7 +83,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( int64_t const intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts, Optional const routed_scaling_factor, bool const use_routing_scales_on_input, int64_t const tile_tokens_dim, - int64_t const routing_method_type, bool enable_pdl) { + int64_t const routing_method_type, + tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex, + bool enable_pdl) { static const std::tuple device_props = [hidden_states] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, @@ -124,6 +163,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; } + args.mDtypeOut = btg::Dtype::Bfloat16; // Output is always bfloat16 for fp8 per-tensor scale args.routing_logits = routing_logits.data_ptr(); auto const routing_bias_dtype = @@ -158,6 +198,13 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( args.num_tokens, top_k, num_experts, tile_tokens_dim); + int32_t max_num_padded_tokens_gemm1 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt)); + int32_t max_num_padded_tokens_gemm2 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); + Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits.device()); Tensor expanded_idx_to_permuted_idx = alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits.device()); @@ -174,20 +221,17 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( routing_logits.device()); // allocate workspace for activation/gemm/finalize kernels - // Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, - // dl_float8_e4m3fn, hidden_states.device()); - // Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size}, - // dl_float8_e4m3fn, 
hidden_states.device()); - Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, + Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8, + hidden_states.device()); + Tensor gemm1_output_scale = + alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, + hidden_states.device()); + Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size}, + dl_uint8, hidden_states.device()); + Tensor activation_output_scale = alloc_tensor( + {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states.device()); + Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, hidden_states.device()); - Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states.device()); - Tensor activation_output = - alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states.device()); - Tensor activation_output_scale = alloc_tensor({intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states.device()); - Tensor gemm2_output = - alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states.device()); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); @@ -257,7 +301,8 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( // setup workspace workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); - workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.total_max_padded_tokens = + std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); workspace.ProjUpTileN = tile_tokens_dim; workspace.routing_expert_indexes = static_cast(expert_indexes.data_ptr()); workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); @@ -283,13 +328,6 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( args.output = output.data_ptr(); args.output_scale = nullptr; - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( - args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim, /*useShuffledMatrixA*/ true); - - auto const moeConfigIndex = - moe_runner.getDefaultValidConfigIndex(args.top_k, args.hidden_size, args.intermediate_size, - args.local_num_experts, args.num_tokens); - auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); Tensor workspace_fc1 = alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); @@ -309,16 +347,56 @@ void trtllm_fp8_per_tensor_scale_moe( TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type, - bool enable_pdl) { + bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl, + Array config_index) { auto dtype = hidden_states.dtype(); if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { + using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; + + // Convert PyTorch dtype to TensorRT-LLM dtype + btg::Dtype mDtypeElt; + if (dtype == dl_float16) { + mDtypeElt = btg::Dtype::Fp16; + } else if (dtype == dl_bfloat16) { + mDtypeElt = btg::Dtype::Bfloat16; + 
} else if (dtype == dl_float8_e4m3fn) { + mDtypeElt = btg::Dtype::E4m3; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; + } + + auto const num_tokens = hidden_states.size(0); + auto const hidden_size = hidden_states.size(1); + bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8 + + std::vector mSupportedTileN = {8, 16, 32, 64, 128}; + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + + // Build runners for all supported tile sizes + std::unordered_map> mRunners; + for (int32_t tile_N : selected_tile_nums) { + // Always use the two-parameter constructor for consistency + mRunners.emplace(tile_N, std::make_unique(mDtypeElt, mUseDeepSeekFp8, tile_N, + /*useShuffledMatrixA*/ true)); + } + + // moeConfigIndex corresponds to pair (tile_N, config) + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + // Autotuner has requested a default or 'fallback' config index + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + local_num_experts, num_tokens); + } + trtllm_fp8_per_tensor_scale_moe_launcher( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, output, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type, - enable_pdl); + routed_scaling_factor, use_routing_scales_on_input, tile_N, routing_method_type, + *mRunners[tile_N], config, enable_pdl); } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype."; } @@ -468,10 +546,6 @@ void trtllm_fp8_block_scale_moe_launcher( routing_logits.device()); // allocate workspace for activation/gemm/finalize kernels - // Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, - // dl_float8_e4m3fn, hidden_states.device()); - // Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size}, - // dl_float8_e4m3fn, hidden_states.device()); Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8, hidden_states.device()); Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens}, @@ -623,16 +697,14 @@ void trtllm_fp8_block_scale_moe_launcher( enable_pdl); } -void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional routing_bias, - TensorView hidden_states, TensorView hidden_states_scale, - TensorView gemm1_weights, TensorView gemm1_weights_scale, - TensorView gemm2_weights, TensorView gemm2_weights_scale, - TensorView output, int64_t num_experts, int64_t top_k, - Optional n_group, Optional topk_group, - int64_t intermediate_size, int64_t local_expert_offset, - int64_t local_num_experts, Optional routed_scaling_factor, - int64_t tile_tokens_dim, int64_t routing_method_type, - bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl) { +void trtllm_fp8_block_scale_moe( + TensorView routing_logits, Optional routing_bias, TensorView hidden_states, + TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, + TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output, + int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, + int64_t 
intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, + Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, bool enable_pdl, Array config_index) { auto dtype = hidden_states.dtype(); if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; @@ -643,24 +715,36 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) << "the value of weight_layout is not recognized"; - // Properly initialize the runner using make_unique like in the original code - auto mRunner = std::make_unique( - mDtypeElt, mUseDeepSeekFp8, tile_tokens_dim, use_shuffled_weight, - static_cast(weight_layout)); - - // Always use fallback config (equivalent to moeConfigIndex == -1 case from original code) auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); - int64_t moeConfigIndex = mRunner->getDefaultValidConfigIndex( - top_k, hidden_size, intermediate_size, local_num_experts, num_tokens); + std::vector mSupportedTileN = {8, 16, 32, 64}; + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + + // Build runners for all supported tile sizes + std::unordered_map> mRunners; + for (int32_t tile_N : selected_tile_nums) { + mRunners.emplace(tile_N, std::make_unique( + mDtypeElt, mUseDeepSeekFp8, tile_N, use_shuffled_weight, + static_cast(weight_layout))); + } + + // moeConfigIndex corresponds to pair (tile_N, config) + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + // Autotuner has requested a default or 'fallback' config index + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + local_num_experts, num_tokens); + } trtllm_fp8_block_scale_moe_launcher( routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, tile_tokens_dim, routing_method_type, *mRunner, moeConfigIndex, - enable_pdl); + routed_scaling_factor, tile_N, routing_method_type, *mRunners[tile_N], config, enable_pdl); } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported hidden state dtype."; } @@ -845,10 +929,6 @@ Array trtllm_fp4_block_scale_moe_launcher( Tensor permuted_idx_to_token_idx = alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device()); - // Tensor expert_weights = alloc_tensor( - // {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states.device()); - // Tensor expert_indexes = alloc_tensor( - // {args.num_tokens, args.top_k}, dl_int32, hidden_states.device(); int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); Tensor expert_count_histogram = alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states.device()); @@ -858,10 +938,6 @@ Array trtllm_fp4_block_scale_moe_launcher( // allocate workspace for activation/gemm/finalize kernels auto const gemm1_output_hidden = dtype_act == btg::Dtype::E2m1 ? intermediate_size / 2 : intermediate_size; - // Tensor gemm1_output = alloc_tensor( - // {max_num_padded_tokens, gemm1_output_hidden}, - // dtype_act == btg::Dtype::Bfloat16 ? 
dl_bfloat16 : dl_float8_e4m3fn, - // hidden_states.device()); Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden}, dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_uint8, hidden_states.device()); @@ -1101,8 +1177,8 @@ Array trtllm_fp4_block_scale_moe( Optional output2_scales_scalar, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - int64_t tile_tokens_dim, int64_t routing_method_type, bool do_finalize, bool enable_pdl, - int64_t gated_act_type, TensorView output, int64_t config_index) { + int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t gated_act_type, + TensorView output, Array config_index) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; int const num_tokens = hidden_states.size(0); @@ -1148,55 +1224,115 @@ Array trtllm_fp4_block_scale_moe( } bool mUseDeepSeekFp8{false}; // FP4 doesn't use DeepSeek FP8 - // Properly initialize the runner using make_unique like in the original code - auto mRunner = std::make_unique( - mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(gated_act_type), /*useShuffledMatrixA*/ true); - - if (config_index == -1) { - config_index = mRunner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, - local_num_experts, num_tokens); + std::vector mSupportedTileN = {8, 16, 32, 64}; + if (mDtypeAct != btg::Dtype::Bfloat16) { + mSupportedTileN.push_back(128); + } + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build runners for all supported tile sizes + std::unordered_map> mRunners; + for (int32_t tile_N : selected_tile_nums) { + mRunners.emplace(tile_N, + std::make_unique(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tile_N, + static_cast(gated_act_type), + /*useShuffledMatrixA*/ true)); } + // moeConfigIndex corresponds to pair (tile_N, config) + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + // Autotuner has requested a default or 'fallback' config index + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + local_num_experts, num_tokens); + } return trtllm_fp4_block_scale_moe_launcher( routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group, - intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, do_finalize, *mRunner, mDtypeAct, mDtypeWeights, - config_index, enable_pdl, output); + intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_N, + routing_method_type, do_finalize, *mRunners[tile_N], mDtypeAct, mDtypeWeights, config, + enable_pdl, output); } -int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_, - int64_t const dtype_weights_, bool const useDeepSeekFp8, - int64_t const top_k, int64_t const hidden_size, - int64_t const intermediate_size, +int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const dtype_weights_, + bool const useDeepSeekFp8, 
int64_t const top_k, + int64_t const hidden_size, int64_t const intermediate_size, int64_t const num_local_experts, int64_t const gated_act_type, int64_t const num_tokens) { auto dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( - dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(gated_act_type), /*useShuffledMatrixA*/ true); - return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens); + std::vector supported_tile_nums = {8, 16, 32, 64}; + // Check if we should add tile size 128 + bool is_fp4_without_bf16_act = + (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) && + dtype_act != btg::Dtype::Bfloat16; + bool is_fp8_per_tensor = + dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; + + if (is_fp4_without_bf16_act || is_fp8_per_tensor) { + supported_tile_nums.push_back(128); + } + std::set selected_tile_nums = + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + + std::unique_ptr moe_runner = + std::make_unique( + dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(), + static_cast(gated_act_type), /*useShuffledMatrixA*/ true); + + return moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); } -Array trtllm_get_valid_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_, - int64_t const dtype_weights_, bool const useDeepSeekFp8, - int64_t const top_k, int64_t const hidden_size, - int64_t const intermediate_size, - int64_t const num_local_experts, - int64_t const gated_act_type, - int64_t const num_tokens) { +Array> trtllm_get_valid_moe_configs( + int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8, + int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size, + int64_t const num_local_experts, int64_t const gated_act_type, bool const use_shuffled_weight, + int64_t const weight_layout, int64_t const num_tokens) { + // returns (tile_N, config) + Array> valid_configs; auto dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( - dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(gated_act_type), /*useShuffledMatrixA*/ true); - return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, - num_tokens); + std::vector supported_tile_nums = {8, 16, 32, 64}; + // Check if we should add tile size 128 + bool is_fp4_without_bf16_act = + (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) && + dtype_act != btg::Dtype::Bfloat16; + bool is_fp8_per_tensor = + dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; + + if (is_fp4_without_bf16_act || is_fp8_per_tensor) { + supported_tile_nums.push_back(128); + } + std::set selected_tile_nums = + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + + for (int32_t tile_N : selected_tile_nums) { + std::unique_ptr moe_runner; + + if (dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3) { + // FP8 block scale MOE runner + moe_runner = std::make_unique( + dtype_weights, useDeepSeekFp8, tile_N, use_shuffled_weight, + static_cast(weight_layout)); + } else { + // FP4 block scale MOE runner + 
moe_runner = std::make_unique( + dtype_act, dtype_weights, useDeepSeekFp8, tile_N, + static_cast(gated_act_type), + /*useShuffledMatrixA*/ true); + } + auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); + for (auto cfg : cfgs) { + valid_configs.push_back({tile_N, cfg}); + } + } + return valid_configs; } namespace trtllm_cubin_loader { diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 5f0e33ccf9..c91878ca0e 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -18,7 +18,6 @@ from enum import IntEnum from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, Union - import torch from ..autotuner import ( @@ -45,7 +44,6 @@ device_support_pdl, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices, - calculate_tile_tokens_dim, register_custom_op, register_fake_op, ) @@ -915,8 +913,9 @@ def __init__( use_deepseek_fp8: bool, hidden_size: int, intermediate_size: int, - gated_act_type: int, - tile_tokens_dim: Optional[int] = None, + gated_act_type: int = GatedActType.SwiGlu, + use_shuffled_weight: bool = False, + weight_layout: int = WeightLayout.MajorK, ): self.num_local_experts = num_local_experts self.top_k = top_k @@ -926,8 +925,18 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gated_act_type = gated_act_type - self.tile_tokens_dim = tile_tokens_dim + self.gated_act_type = GatedActType(gated_act_type) + self.use_shuffled_weight = use_shuffled_weight + self.weight_layout = WeightLayout(weight_layout) + if ( + not self.use_shuffled_weight + or self.weight_layout != WeightLayout.MajorK + ): + assert ( + self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3 + ), ( + "use_shuffled_weight is False or weight_layout is not MajorK is only supported for FP8 block scale" + ) def get_valid_tactics( self, @@ -943,18 +952,8 @@ def get_valid_tactics( *extra_inputs, ) = inputs num_tokens = routing_logits.shape[0] - tile_tokens_dim = ( - calculate_tile_tokens_dim( - num_tokens, - self.num_local_experts, - self.top_k, - 64 if self.dtype_act == DtypeTrtllmGen.Bfloat16 else 128, - ) - if self.tile_tokens_dim is None - else self.tile_tokens_dim - ) + instance_key = ( - tile_tokens_dim, self.dtype_act, self.dtype_weights, self.use_deepseek_fp8, @@ -963,6 +962,8 @@ def get_valid_tactics( self.intermediate_size, self.num_local_experts, self.gated_act_type, + self.use_shuffled_weight, + self.weight_layout, num_tokens, ) if instance_key not in MoERunner.valid_tactics_dict: @@ -992,16 +993,6 @@ def forward( *extra_inputs, ) = inputs num_tokens = routing_logits.shape[0] - tile_tokens_dim = ( - calculate_tile_tokens_dim( - num_tokens, - self.num_local_experts, - self.top_k, - 64 if self.dtype_act == DtypeTrtllmGen.Bfloat16 else 128, - ) - if self.tile_tokens_dim is None - else self.tile_tokens_dim - ) extra_input_idx = 0 if trtllm_gen_dtype_has_scale(self.dtype_act): @@ -1026,42 +1017,106 @@ def forward( hidden_states_scale.dim() == 2 and hidden_states_scale.shape[0] == num_tokens ), "hidden_states_scale's first dimension must be batch size" - # TODO(siyuan): support fp8 - moe_op.trtllm_fp4_block_scale_moe( - routing_logits, - topk_ids, - expert_weights, - kwargs["routing_bias"], - hidden_states, - hidden_states_scale, # hidden_states_scale - kwargs["gemm1_weights"], - kwargs["gemm1_weights_scale"], - kwargs["gemm1_bias"], - kwargs["gemm1_alpha"], - kwargs["gemm1_beta"], - 
kwargs["gemm1_clamp_limit"], - kwargs["gemm2_weights"], - kwargs["gemm2_weights_scale"], - kwargs["gemm2_bias"], - kwargs["output1_scale_scalar"], - kwargs["output1_scale_gate_scalar"], - kwargs["output2_scale_scalar"], - kwargs["num_experts"], - self.top_k, - kwargs["n_group"], - kwargs["topk_group"], - self.intermediate_size, - kwargs["local_expert_offset"], - self.num_local_experts, - kwargs["routed_scaling_factor"], - tile_tokens_dim, - kwargs["routing_method_type"], - kwargs["enable_pdl"], - kwargs["do_finalize"], - self.gated_act_type, - output, - tactic, - ) + # Choose the appropriate operation based on data types + if ( + self.dtype_act == DtypeTrtllmGen.E4m3 + and self.dtype_weights == DtypeTrtllmGen.E4m3 + ): + # FP8 operations + if self.use_deepseek_fp8: + # FP8 block scale + current_num_tokens = hidden_states.shape[0] + current_hidden_size = hidden_states.shape[1] + current_hidden_states_scale = torch.full( + (current_hidden_size // 128, current_num_tokens), + 2.0, + dtype=torch.float, + device=hidden_states.device, + ) + moe_op.trtllm_fp8_block_scale_moe( + routing_logits, + kwargs["routing_bias"], + hidden_states, + current_hidden_states_scale, + kwargs["gemm1_weights"], + kwargs["gemm1_weights_scale"], + kwargs["gemm2_weights"], + kwargs["gemm2_weights_scale"], + output, + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routed_scaling_factor"], + kwargs["routing_method_type"], + kwargs["use_shuffled_weight"], + kwargs["weight_layout"], + kwargs["enable_pdl"], + [-1, -1] if tactic == -1 else tactic, + ) + else: + # FP8 per tensor scale + moe_op.trtllm_fp8_per_tensor_scale_moe( + routing_logits, + kwargs["routing_bias"], + hidden_states, + kwargs["gemm1_weights"], + kwargs["output1_scales_scalar"], + kwargs["output1_scales_gate_scalar"], + kwargs["gemm2_weights"], + kwargs["output2_scales_scalar"], + output, + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routed_scaling_factor"], + kwargs["use_routing_scales_on_input"], + kwargs["routing_method_type"], + kwargs["enable_pdl"], + [-1, -1] if tactic == -1 else tactic, + ) + else: + moe_op.trtllm_fp4_block_scale_moe( + routing_logits, + topk_ids, + expert_weights, + kwargs["routing_bias"], + hidden_states, + hidden_states_scale, # hidden_states_scale + kwargs["gemm1_weights"], + kwargs["gemm1_weights_scale"], + kwargs["gemm1_bias"], + kwargs["gemm1_alpha"], + kwargs["gemm1_beta"], + kwargs["gemm1_clamp_limit"], + kwargs["gemm2_weights"], + kwargs["gemm2_weights_scale"], + kwargs["gemm2_bias"], + kwargs["output1_scale_scalar"], + kwargs["output1_scale_gate_scalar"], + kwargs["output2_scale_scalar"], + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routed_scaling_factor"], + kwargs["routing_method_type"], + kwargs["enable_pdl"], + kwargs["do_finalize"], + self.gated_act_type, + output, + [-1, -1] if tactic == -1 else tactic, + ) @classmethod @functools.lru_cache(maxsize=None) @@ -1111,14 +1166,67 @@ def trtllm_fp8_per_tensor_scale_moe_op( local_num_experts: int, routed_scaling_factor: Optional[float], use_routing_scales_on_input: bool, - tile_tokens_dim: int = 8, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: 
int = 8192, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) + # Use AutoTuner to select the best tactic + tuner = AutoTuner.get() + MoERunner.refine_tuning_config(tune_max_num_tokens) + + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] + + # Create workspace buffers output = torch.empty( - hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device + num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device + ) + topk_ids = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + ) + expert_weights = torch.empty( + num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device + ) + + dtype_act = DtypeTrtllmGen.E4m3 # FP8 activation + dtype_weights = DtypeTrtllmGen.E4m3 # FP8 weights + + moe_runner = MoERunner( + top_k=top_k, + num_local_experts=local_num_experts, + dtype_act=dtype_act, + dtype_weights=dtype_weights, + use_deepseek_fp8=False, # per_tensor mode + hidden_size=hidden_size, + intermediate_size=intermediate_size, + weight_layout=WeightLayout.MajorK, + use_shuffled_weight=True, + ) + + inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] + + _, tactic = tuner.choose_one( + "flashinfer::trtllm_fp8_per_tensor_scale_moe", + [moe_runner], + MoERunner.tuning_config_no_hidden_states_scales, # FP8 per-tensor doesn't use hidden_states_scale + inputs, + routing_bias=routing_bias, + gemm1_weights=gemm1_weights, + output1_scales_scalar=output1_scales_scalar, + output1_scales_gate_scalar=output1_scales_gate_scalar, + gemm2_weights=gemm2_weights, + output2_scales_scalar=output2_scales_scalar, + num_experts=num_experts, + n_group=n_group, + topk_group=topk_group, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=use_routing_scales_on_input, + routing_method_type=routing_method_type, + enable_pdl=enable_pdl, ) # Call the C++ function moe_op.trtllm_fp8_per_tensor_scale_moe( @@ -1140,9 +1248,9 @@ def trtllm_fp8_per_tensor_scale_moe_op( local_num_experts, routed_scaling_factor, use_routing_scales_on_input, - tile_tokens_dim, routing_method_type, enable_pdl, + [-1, -1] if tactic == -1 else tactic, ) return output @@ -1165,7 +1273,6 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( local_num_experts: int, routed_scaling_factor: Optional[float], use_routing_scales_on_input: bool, - tile_tokens_dim: int = 8, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, ): @@ -1196,15 +1303,78 @@ def trtllm_fp8_block_scale_moe_op( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: int, routing_method_type: int, use_shuffled_weight: bool = False, weight_layout: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) + # Use AutoTuner to select the best tactic - follow FP4 pattern exactly + tuner = AutoTuner.get() + MoERunner.refine_tuning_config(tune_max_num_tokens) + + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] + + # Create workspace buffers + output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device + ) + topk_ids = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + ) + expert_weights = torch.empty( + num_tokens, top_k, dtype=routing_logits.dtype, 
device=hidden_states.device + ) + + dtype_act = DtypeTrtllmGen.E4m3 # FP8 activation + dtype_weights = DtypeTrtllmGen.E4m3 # FP8 weights + + moe_runner = MoERunner( + top_k=top_k, + num_local_experts=local_num_experts, + dtype_act=dtype_act, + dtype_weights=dtype_weights, + use_deepseek_fp8=True, # block_scale mode + hidden_size=hidden_size, + intermediate_size=intermediate_size, + weight_layout=weight_layout, + use_shuffled_weight=use_shuffled_weight, + ) + + inputs = [ + output, + routing_logits, + topk_ids, + expert_weights, + hidden_states, + hidden_states_scale, + ] + + _, tactic = tuner.choose_one( + "flashinfer::trtllm_fp8_block_scale_moe", + [moe_runner], + MoERunner.tuning_config_with_hidden_states_scales, # FP8 block-scale uses hidden_states_scale + inputs, + routing_bias=routing_bias, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + num_experts=num_experts, + n_group=n_group, + topk_group=topk_group, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=routing_method_type, + use_shuffled_weight=use_shuffled_weight, + weight_layout=weight_layout, + enable_pdl=enable_pdl, + ) # Call the C++ function for block scale MoE moe_op.trtllm_fp8_block_scale_moe( routing_logits, @@ -1224,11 +1394,11 @@ def trtllm_fp8_block_scale_moe_op( local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, use_shuffled_weight, weight_layout, enable_pdl, + [-1, -1] if tactic == -1 else tactic, ) return output @@ -1252,7 +1422,6 @@ def _fake_trtllm_fp8_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: int = 8, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, @@ -1294,7 +1463,6 @@ def trtllm_fp4_block_scale_moe_op( local_expert_offset: int, num_local_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int], routing_method_type: int, do_finalize: bool, enable_pdl: Optional[bool] = None, @@ -1340,13 +1508,6 @@ def trtllm_fp4_block_scale_moe_op( dtype_weights = deduce_trtllm_gen_tensor_dtype( gemm1_weights, gemm1_weights_scale ) - if tile_tokens_dim is None: - tile_tokens_dim = calculate_tile_tokens_dim( - num_tokens, - num_experts, - top_k, - max_tile_tokens_dim=64 if dtype_act == DtypeTrtllmGen.Bfloat16 else 128, - ) moe_runner = MoERunner( top_k=top_k, num_local_experts=num_local_experts, @@ -1356,9 +1517,8 @@ def trtllm_fp4_block_scale_moe_op( hidden_size=hidden_size, intermediate_size=intermediate_size, gated_act_type=gated_act_type, - # NOTE(siyuan): do not fix the tile_tokens_dim to let tunnable runner decide the tile_tokens_dim itself. - # however, when the user chooses a different heuristic for tile_tokens_dim, the autotuner will fail to find the correct cached tactics. 
- # tile_tokens_dim=tile_tokens_dim, + weight_layout=WeightLayout.MajorK, + use_shuffled_weight=True, ) tunning_config = ( MoERunner.tuning_config_no_hidden_states_scales @@ -1434,13 +1594,12 @@ def trtllm_fp4_block_scale_moe_op( local_expert_offset, num_local_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, do_finalize, enable_pdl, gated_act_type, output, - tactic, + [-1, -1] if tactic == -1 else tactic, ) if do_finalize: return [output] @@ -1480,7 +1639,6 @@ def _fake_trtllm_fp4_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int], routing_method_type: int, do_finalize: bool, enable_pdl: bool, @@ -1549,6 +1707,12 @@ def trtllm_fp8_per_tensor_scale_moe( Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ + if tile_tokens_dim is not None: + logger.warning_once( + "tile_tokens_dim in trtllm_fp8_per_tensor_scale_moe is planned for deprecation " + "in a future release. Please remove it from your code as tile_tokens_dim will no " + "longer be supported after v0.5.0." + ) return get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, @@ -1567,7 +1731,6 @@ def trtllm_fp8_per_tensor_scale_moe( local_num_experts, routed_scaling_factor, use_routing_scales_on_input, - tile_tokens_dim, routing_method_type, enable_pdl, ) @@ -1590,7 +1753,7 @@ def trtllm_fp8_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: int = 8, + tile_tokens_dim: Optional[int] = None, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, @@ -1621,6 +1784,12 @@ def trtllm_fp8_block_scale_moe( Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ + if tile_tokens_dim is not None: + logger.warning_once( + "tile_tokens_dim in trtllm_fp8_block_scale_moe is planned for deprecation " + "in a future release. Please remove it from your code as tile_tokens_dim will no " + "longer be supported after v0.5.0." + ) output = torch.empty( hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device ) @@ -1642,7 +1811,6 @@ def trtllm_fp8_block_scale_moe( local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, use_shuffled_weight, weight_layout, @@ -1675,7 +1843,7 @@ def trtllm_fp4_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int] = None, + tile_tokens_dim: Optional[int], routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, @@ -1726,7 +1894,7 @@ def trtllm_fp4_block_scale_moe( local_expert_offset (int): Offset of local experts in global expert space local_num_experts (int): Number of experts handled by this device routed_scaling_factor (Optional[float]): Scaling factor for routing (can be None for some routing methods) - tile_tokens_dim (int): Tile dimension for tokens (default: 8) + tile_tokens_dim (Optional[int]): Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type (int): Type of routing method to use (default: 0) - 0: Default (Softmax -> TopK) - 1: Renormalize (TopK -> Softmax) @@ -1745,6 +1913,12 @@ def trtllm_fp4_block_scale_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. 
Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + if tile_tokens_dim is not None: + logger.warning_once( + "tile_tokens_dim in trtllm_fp4_block_scale_moe is planned for deprecation " + "in a future release. Please remove it from your code as tile_tokens_dim will no " + "longer be supported after v0.5.0." + ) return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( routing_logits, None, @@ -1772,7 +1946,6 @@ def trtllm_fp4_block_scale_moe( local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, do_finalize, enable_pdl, @@ -1807,7 +1980,7 @@ def trtllm_fp4_block_scale_routed_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int] = None, + tile_tokens_dim: Optional[int], routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, @@ -1860,7 +2033,7 @@ def trtllm_fp4_block_scale_routed_moe( local_expert_offset (int): Offset of local experts in global expert space local_num_experts (int): Number of experts handled by this device routed_scaling_factor (Optional[float]): Scaling factor for routing (can be None for some routing methods) - tile_tokens_dim (int): Tile dimension for tokens (default: 8) + tile_tokens_dim (Optional[int]): Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type (int): Type of routing method to use (default: 0) - 0: Default (Softmax -> TopK) - 1: Renormalize (TopK -> Softmax) @@ -1879,6 +2052,12 @@ def trtllm_fp4_block_scale_routed_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + if tile_tokens_dim is not None: + logger.warning_once( + "tile_tokens_dim in trtllm_fp4_block_scale_routed_moe is planned for deprecation " + "in a future release. Please remove it from your code as tile_tokens_dim will no " + "longer be supported after v0.5.0." + ) return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( None, topk_ids, @@ -1906,7 +2085,6 @@ def trtllm_fp4_block_scale_routed_moe( local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, do_finalize, enable_pdl, diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index e7dec73723..27034a4054 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -1,10 +1,11 @@ import dataclasses +import functools import logging import os from contextlib import nullcontext from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence, Union, Hashable import tvm_ffi from filelock import FileLock @@ -60,6 +61,33 @@ def __init__(self, name): ) ) + def debug_once(self, msg: str, *args: Hashable) -> None: + """ + As [`debug`][logging.Logger.debug], but subsequent calls with + the same message are silently dropped. + """ + self._print_once(self.debug, msg, *args) + + def info_once(self, msg: str, *args: Hashable) -> None: + """ + As [`info`][logging.Logger.info], but subsequent calls with + the same message are silently dropped. 
+ """ + self._print_once(self.info, msg, *args) + + def warning_once(self, msg: str, *args: Hashable) -> None: + """ + As [`warning`][logging.Logger.warning], but subsequent calls with + the same message are silently dropped. + """ + self._print_once(self.warning, msg, *args) + + @functools.lru_cache(maxsize=None) + def _print_once(self, log_method, msg: str, *args: Hashable) -> None: + """Helper method to log messages only once per unique (msg, args) combination.""" + # Note: stacklevel=3 to show the caller's location, not this helper method + log_method(msg, *args, stacklevel=3) + logger = FlashInferJITLogger("flashinfer.jit") diff --git a/tests/conftest.py b/tests/conftest.py index dc81dc0db2..768eec8fa3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -137,11 +137,11 @@ def is_cuda_oom_error_str(e: str) -> bool: return "CUDA" in e and "out of memory" in e -@pytest.hookimpl(tryfirst=True) +@pytest.hookimpl(wrapper=True) def pytest_runtest_call(item): # skip OOM error and missing JIT cache errors try: - item.runtest() + yield except (torch.cuda.OutOfMemoryError, RuntimeError) as e: if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)): pytest.skip("Skipping due to OOM") diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index a093d4c0aa..df19e00310 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -14,10 +14,10 @@ limitations under the License. """ +import pytest from abc import ABC, abstractmethod from enum import IntEnum from typing import Dict -import pytest import torch from cuda.bindings import runtime from torch.nn import functional as F @@ -45,7 +45,7 @@ get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) -from flashinfer.utils import calculate_tile_tokens_dim, get_compute_capability +from flashinfer.utils import get_compute_capability def check_cuda(err): @@ -202,7 +202,7 @@ def _run_moe_computation(self, runtime_args): local_expert_offset=0, local_num_experts=self.config["num_experts"], routed_scaling_factor=self.config["routed_scaling"], - tile_tokens_dim=self.config["tile_tokens_dim"], + tile_tokens_dim=None, routing_method_type=self.config["routing_method_type"], gated_act_type=self.config["gated_act_type"], do_finalize=True, @@ -549,7 +549,6 @@ def call_moe( routed_scaling = kwargs["routed_scaling"] gated_act_type = kwargs["gated_act_type"] routing_method_type = kwargs["routing_method_type"] - tile_tokens_dim = kwargs["tile_tokens_dim"] # Create CUDA graph configuration config = { @@ -560,7 +559,6 @@ def call_moe( "top_k_groups": top_k_groups, "intermediate_size": intermediate_size, "routed_scaling": routed_scaling, - "tile_tokens_dim": tile_tokens_dim, "gated_act_type": gated_act_type, "routing_method_type": routing_method_type, } @@ -727,8 +725,8 @@ def prepare_static_weights_for_kernel( tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k) gemm1_weights_fp8_shuffled.append(tmp_weights1) - gemm2_weights_fp8_shuffled.append(tmp_weights2) + kernel_gemm1_weights = torch.stack(gemm1_weights_fp8_shuffled).view( torch.float8_e4m3fn ) @@ -761,7 +759,6 @@ def call_moe( intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] - tile_tokens_dim = kwargs["tile_tokens_dim"] enable_pdl = kwargs.get("enable_pdl") hidden_states_scale = kwargs["hidden_states_scale"] hidden_states_quant = kwargs["hidden_states_quant"] @@ -772,29 +769,31 @@ def 
call_moe( "NaN detected in hidden_states_fp8" ) - output = trtllm_fp8_block_scale_moe( - expert_logits, - routing_bias, - hidden_states_fp8, - hidden_states_scale, - static_data["gemm1_weights"], - static_data["gemm1_scales"], - static_data["gemm2_weights"], - static_data["gemm2_scales"], - num_experts, - top_k, - n_groups, - top_k_groups, - intermediate_size, - 0, - num_experts, - routed_scaling, - tile_tokens_dim, - routing_method_type, - use_shuffled_weight=static_data["use_shuffled_weight"], - weight_layout=static_data["weight_layout"], - enable_pdl=enable_pdl, - ) + # Use autotuner for optimal kernel selection + with autotune(True): + output = trtllm_fp8_block_scale_moe( + expert_logits, + routing_bias, + hidden_states_fp8, + hidden_states_scale, + static_data["gemm1_weights"], + static_data["gemm1_scales"], + static_data["gemm2_weights"], + static_data["gemm2_scales"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + routed_scaling, + None, + routing_method_type, + use_shuffled_weight=static_data["use_shuffled_weight"], + weight_layout=static_data["weight_layout"], + enable_pdl=enable_pdl, + ) return output.to(torch.float) @@ -937,39 +936,40 @@ def call_moe( intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] - tile_tokens_dim = kwargs["tile_tokens_dim"] # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_fp8, _ = quant_fp8_per_tensor( hidden_states_orig, hidden_states_scale_global ) - output = trtllm_fp8_per_tensor_scale_moe( - ( - expert_logits.to(torch.bfloat16) - if routing_method_type == RoutingMethodType.Llama4 - else expert_logits - ), - routing_bias, - hidden_states_fp8, - static_data["gemm1_weights"], - static_data["scale_c_fc1"], - static_data["scale_gate_fc1"], - static_data["gemm2_weights"], - static_data["scale_c_fc2"], - num_experts, - top_k, - n_groups, - top_k_groups, - intermediate_size, - 0, - num_experts, - routed_scaling, - routing_method_type - == RoutingMethodType.Llama4, # Use_routing_scales_on_input - tile_tokens_dim, - routing_method_type, - ) + # Use autotuner for optimal kernel selection + with autotune(True): + output = trtllm_fp8_per_tensor_scale_moe( + ( + expert_logits.to(torch.bfloat16) + if routing_method_type == RoutingMethodType.Llama4 + else expert_logits + ), + routing_bias, + hidden_states_fp8, + static_data["gemm1_weights"], + static_data["scale_c_fc1"], + static_data["scale_gate_fc1"], + static_data["gemm2_weights"], + static_data["scale_c_fc2"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + routed_scaling, + routing_method_type + == RoutingMethodType.Llama4, # Use_routing_scales_on_input + None, + routing_method_type, + ) return output.to(torch.float) @@ -985,8 +985,6 @@ def get_tolerances(self): # ==================================================================================== # Quantizer Factory # ==================================================================================== - - def get_moe_impl(quant_mode: QuantMode): """Factory function to get the appropriate MoE implementation.""" if quant_mode == QuantMode.FP8_BLOCK_SCALE: @@ -1815,7 +1813,6 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "intermediate_size": args.intermediate_size, "routed_scaling": kwargs["routed_scaling"], "routing_method_type": kwargs["routing_method_type"], - "tile_tokens_dim": kwargs["tile_tokens_dim"], 
"do_finalize": True, "gated_act_type": args.gated_act_type, "hidden_states_scale": args.hidden_states_scale, @@ -1837,203 +1834,16 @@ def cache_permute_indices(): return _cache_permute_indices -@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) -@pytest.mark.parametrize("hidden_size", [1024, 8192]) -@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) -@pytest.mark.parametrize( - "moe_impl", - [ - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), - pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), - pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), - ], -) -@pytest.mark.parametrize( - "routing_config", - [ - pytest.param( - { - "num_experts": 384, - "top_k": 8, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP4Moe, - FP8BlockScaleMoe, - ], - }, - id="kimi_k2", - ), - pytest.param( - { - "num_experts": 256, - "top_k": 8, - "padding": 8, - "n_groups": 8, - "top_k_groups": 4, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP4Moe, - FP8BlockScaleMoe, - ], - }, - id="DSv3", - ), - pytest.param( - { - "num_experts": 72, - "top_k": 6, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP4Moe, - FP8BlockScaleMoe, - ], - }, - id="DSLite", - ), - pytest.param( - { - "num_experts": 256, - "top_k": 8, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP8PerTensorMoe, FP4Moe], - }, - id="Renorm", - marks=pytest.mark.skip( - reason="Disabled for testing speed - similar to RenormalizeNaive" - ), - ), - pytest.param( - { - "num_experts": 128, - "top_k": 10, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], - }, - id="Qwen3_next", - ), - pytest.param( - { - "num_experts": 128, - "top_k": 8, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.RenormalizeNaive, - "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - }, - id="RenormNaive", - ), - pytest.param( - { - "num_experts": 16, - "top_k": 2, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.TopK, - "compatible_moe_impls": [FP4Moe], - }, - id="TopK", - ), - pytest.param( - { - "num_experts": 128, - "top_k": 1, - "padding": 8, - "n_groups": 0, - "top_k_groups": 0, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.Llama4, - "compatible_moe_impls": [FP8PerTensorMoe], - }, - id="Llama4", - ), - ], -) -@pytest.mark.parametrize( - "weight_processing", - [ - pytest.param( - { - "use_shuffled_weight": False, - "layout": WeightLayout.MajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - }, - 
id="NoShuffle_MajorK", - ), - pytest.param( - { - "use_shuffled_weight": True, - "layout": WeightLayout.MajorK, - "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], - }, - id="Shuffled_MajorK", - ), - pytest.param( - { - "use_shuffled_weight": True, - "layout": WeightLayout.BlockMajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - }, - id="Shuffled_BlockMajorK", - ), - ], -) -@pytest.mark.parametrize( - "gated_act_type", - [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), - pytest.param(GatedActType.GeGlu, id="GeGlu"), - ], -) -def test_moe_quantization_classes( - num_tokens, - hidden_size, - intermediate_size, +def skip_checks( moe_impl, routing_config, weight_processing, gated_act_type, - cache_permute_indices, + num_tokens, + hidden_size, + intermediate_size, ): - """ - Test MoE implementations using separated quantization workflow. - - This test demonstrates the clean separation between: - - Static weight quantization (done offline) - - Dynamic input quantization (done at runtime) - - Each quantization class clearly shows which precision is being used. - """ + """Common skip logic for all tests.""" compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] not in [10]: pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") @@ -2044,14 +1854,12 @@ def test_moe_quantization_classes( or routing_config["routing_method_type"] != RoutingMethodType.TopK or num_tokens > 128 ): - # GeGlu is only supported for FP4Moe FP4_NVFP4_NVFP4 and TopK routing pytest.skip( f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" ) elif gated_act_type == GatedActType.SwiGlu and ( hidden_size > 1024 or intermediate_size > 1024 ): - # Skip some tests for SwiGlu for testing speed pytest.skip( f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" ) @@ -2070,6 +1878,10 @@ def test_moe_quantization_classes( pytest.skip( f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}" ) + if intermediate_size not in routing_config["compatible_intermediate_size"]: + pytest.skip( + f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)" + ) # TODO(jimmzhou): enable MxFP4xBf16 on SM103 if ( @@ -2082,6 +1894,30 @@ def test_moe_quantization_classes( "Note(jimmzhou): Make MxFP4xBf16 nonfunctional on SM103 to avoid B200 regression" ) + +def run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, +): + """Common test logic for all routing methods.""" + skip_checks( + moe_impl, + routing_config, + weight_processing, + gated_act_type, + num_tokens, + hidden_size, + intermediate_size, + ) + + torch.cuda.synchronize() + moe_impl._cache_permute_indices = cache_permute_indices seed = 0 @@ -2096,17 +1932,6 @@ def test_moe_quantization_classes( num_experts = routing_config["num_experts"] routing_method_type = routing_config["routing_method_type"] - tile_tokens_dim = calculate_tile_tokens_dim( - num_tokens, - num_experts, - top_k, - max_tile_tokens_dim=128 - if ( - type(moe_impl) is FP4Moe and moe_impl.quant_mode != QuantMode.FP4_MXFP4_Bf16 - ) - else 64, - ) - # Validation checks assert top_k <= num_experts assert top_k <= 10 @@ -2117,15 +1942,12 @@ def test_moe_quantization_classes( assert num_experts % 4 == 0 assert top_k 
< (top_k_groups * num_experts / n_groups) - # Create test data based on routing method and quantization mode - # Different kernels have different dtype requirements for routing logits + # Create test data based on routing method if routing_method_type == RoutingMethodType.DeepSeekV3: - # DeepSeekV3 uses float for routing logits expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( torch.float ) else: - # Other routing methods (Renormalize, RenormalizeNaive, Llama4) use bfloat16 expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( torch.bfloat16 ) @@ -2191,12 +2013,12 @@ def test_moe_quantization_classes( f"Routing method {routing_method_type} not implemented" ) - # 1. Quantize weights offline (static, done once) + compute global scale factors + # 1. Quantize weights offline weights_data = moe_impl.quantize_weights( gemm1_weights, gemm2_weights, hidden_states ) - # 2. Quantize inputs at runtime (dynamic, done per inference) using pre-computed scales + # 2. Quantize inputs at runtime inputs_data = moe_impl.quantize_inputs( hidden_states, weights_data["hidden_states_scale_global"] ) @@ -2227,14 +2049,13 @@ def test_moe_quantization_classes( gated_act_type, ) - # Compute reference output using the moe_impl + # Compute reference output output_dequant_reference, args_dequant = moe_impl.compute_reference(args) - # Validate that reference computation succeeded if output_dequant_reference is None: pytest.fail("Reference computation failed to produce output") - # Compute actual output using the moe_impl + # Compute actual output output_dequant_actual = moe_impl.compute_production( args_dequant, args, @@ -2247,15 +2068,12 @@ def test_moe_quantization_classes( top_k_groups=top_k_groups, routed_scaling=routed_scaling, routing_method_type=routing_method_type, - tile_tokens_dim=tile_tokens_dim, weight_processing=weight_processing, enable_pdl=True, - hidden_states_quant=inputs_data[ - "hidden_states" - ], # NOTE(yingyi): only for fp8 block scale for now, refactor later + hidden_states_quant=inputs_data["hidden_states"], ) - # Compare outputs using moe_impl-specific tolerances + # Compare outputs tolerances = moe_impl.get_tolerances() check_accuracy( output_dequant_reference, @@ -2264,3 +2082,363 @@ def test_moe_quantization_classes( rtol=tolerances["rtol"], percent=tolerances["percent"], ) + + +# Test: DeepSeekV3 routing +@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), + pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 384, + "top_k": 8, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [512, 1024, 2048], + }, + id="kimi_k2", + ), + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + 
"compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [512, 1024, 2048], + }, + id="DSv3", + ), + pytest.param( + { + "num_experts": 72, + "top_k": 6, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [384, 768], + }, + id="DSLite", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="NoShuffle_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="Shuffled_BlockMajorK", + ), + ], +) +@pytest.mark.parametrize( + "gated_act_type", + [ + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), + ], +) +def test_deepseekv3_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, +): + """Test DeepSeekV3 routing configurations.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, + ) + + +# Test: Renormalize routing +@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), + pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_intermediate_size": [384, 768, 1024, 2048], + }, + id="Renorm", + marks=pytest.mark.skip(reason="Skip temporary"), + ), + pytest.param( + { + "num_experts": 512, + "top_k": 10, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_intermediate_size": [512], + }, + id="Qwen3_next", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + ], +) +@pytest.mark.parametrize( + "gated_act_type", + [ + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), + ], +) +def test_renormalize_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, 
+ cache_permute_indices, +): + """Test Renormalize routing configurations.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, + ) + + +# Test: TopK routing +@pytest.mark.parametrize("num_tokens", [1, 8, 128]) # Limited for GeGlu +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [384, 512, 768, 1024]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 16, + "top_k": 2, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.TopK, + "compatible_moe_impls": [FP4Moe], + "compatible_intermediate_size": [384, 512, 768, 1024], + }, + id="TopK", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + ], +) +@pytest.mark.parametrize( + "gated_act_type", + [ + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), + ], +) +def test_topk_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, +): + """Test TopK routing configuration.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, + ) + + +# Test: Llama4 routing +@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 128, + "top_k": 1, + "padding": 8, + "n_groups": 0, + "top_k_groups": 0, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.Llama4, + "compatible_moe_impls": [FP8PerTensorMoe], + "compatible_intermediate_size": [1024, 2048], + }, + id="Llama4", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + ], +) +@pytest.mark.parametrize( + "gated_act_type", + [ + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + ], +) +def test_llama4_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, +): + """Test Llama4 routing configuration with FP8 per-tensor.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, + ) From 6a962ef78507c14f95cb0b0c33e08edae65732d2 Mon Sep 17 00:00:00 2001 From: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> Date: Thu, 30 Oct 2025 05:45:00 +0800 Subject: [PATCH 006/130] Fix trtllm-gen attention illegal memory access (#2002) 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This PR fixes illegal memory access of trtllm-gen attention kernels. It changes the workspace buffer from `int_workspace_buffer` to `float_workspace_buffer`. `int_workspace_buffer` is a fixed sized buffer and not initialized to zero, which should not be used. ## ๐Ÿ” Related Issues Issue #1928 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Fixed memory allocation in the decode module to improve computation accuracy and stability during text generation. --- flashinfer/decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 467152af38..45bc2c58ad 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1988,7 +1988,7 @@ def paged_run( q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect paged_k_cache, paged_v_cache, - int_workspace_buffer, + float_workspace_buffer, block_tables, kv_lens_buffer, max_kv_len, From 3c079216652b5706d30699121045c93769129900 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Wed, 29 Oct 2025 20:26:06 -0700 Subject: [PATCH 007/130] release: Bump version for v0.5.0rc1 release; (#2008) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Update version in `version.txt` to v0.5.0 as we prepare for v0.5.0rc1 release. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). 
## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Version bump to 0.5.0 (no functional changes) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 267577d47e..8f0916f768 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.1 +0.5.0 From b9287c91049d78b3fbfa1f8172bcef3d8a56e044 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 30 Oct 2025 09:42:49 -0700 Subject: [PATCH 008/130] bugfix: fix regex in update wheel index script (#2009) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description The regex cannot recognize release candidates (`v0.5.0rc1`) or post releases (`v1.2.3.post1`): https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551 This PR fixes the issue. ## ๐Ÿ” Related Issues https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Enhanced version string parsing in the wheel package indexing process to support more complex version formats, including pre-release, post-release, and development versions, ensuring compatibility with PEP 440 versioning standards. 
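
For reference, the updated pattern can be exercised in isolation as follows; the wheel filenames here are illustrative examples, not actual release artifacts:

```python
import re

# Pattern taken from this patch: PEP 440 base_version[{a|b|rc}N][.postN][.devN]
FLASHINFER_PYTHON_RE = re.compile(
    r"flashinfer_python-([0-9.]+(?:(?:a|b|rc)\d+)?(?:\.post\d+)?(?:\.dev\d+)?)-"
)

# The previous pattern, r"flashinfer_python-([0-9.]+(?:\.dev\d+)?)-", matched only
# the plain and .dev versions below and failed on the rc/post releases.
for name in [
    "flashinfer_python-0.5.0rc1-py3-none-any.whl",
    "flashinfer_python-1.2.3.post1-py3-none-any.whl",
    "flashinfer_python-0.4.1.dev20251030-py3-none-any.whl",
    "flashinfer_python-0.4.1-py3-none-any.whl",
]:
    m = FLASHINFER_PYTHON_RE.match(name)
    print(f"{name} -> {m.group(1) if m else 'no match'}")
```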
--- scripts/update_whl_index.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py index 474ec61ea9..cb9ed3d183 100644 --- a/scripts/update_whl_index.py +++ b/scripts/update_whl_index.py @@ -31,7 +31,11 @@ def get_package_info(wheel_path: pathlib.Path) -> Optional[dict]: wheel_name = wheel_path.name # Try flashinfer-python pattern - match = re.match(r"flashinfer_python-([0-9.]+(?:\.dev\d+)?)-", wheel_name) + # Supports PEP 440: base_version[{a|b|rc}N][.postN][.devN] + match = re.match( + r"flashinfer_python-([0-9.]+(?:(?:a|b|rc)\d+)?(?:\.post\d+)?(?:\.dev\d+)?)-", + wheel_name, + ) if match: version = match.group(1) return { @@ -41,7 +45,11 @@ def get_package_info(wheel_path: pathlib.Path) -> Optional[dict]: } # Try flashinfer-cubin pattern - match = re.match(r"flashinfer_cubin-([0-9.]+(?:\.dev\d+)?)-", wheel_name) + # Supports PEP 440: base_version[{a|b|rc}N][.postN][.devN] + match = re.match( + r"flashinfer_cubin-([0-9.]+(?:(?:a|b|rc)\d+)?(?:\.post\d+)?(?:\.dev\d+)?)-", + wheel_name, + ) if match: version = match.group(1) return { @@ -51,7 +59,11 @@ def get_package_info(wheel_path: pathlib.Path) -> Optional[dict]: } # Try flashinfer-jit-cache pattern (has CUDA suffix in version) - match = re.match(r"flashinfer_jit_cache-([0-9.]+(?:\.dev\d+)?\+cu\d+)-", wheel_name) + # Supports PEP 440: base_version[{a|b|rc}N][.postN][.devN]+cuXXX + match = re.match( + r"flashinfer_jit_cache-([0-9.]+(?:(?:a|b|rc)\d+)?(?:\.post\d+)?(?:\.dev\d+)?\+cu\d+)-", + wheel_name, + ) if match: version = match.group(1) cuda_ver = get_cuda_version(wheel_name) From a5ff03391852c569ddb7d7d0c5ac3af490fdc124 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Thu, 30 Oct 2025 14:23:36 -0700 Subject: [PATCH 009/130] fix: Enable SM121 for mm_fp4 (#2012) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description In #1809 we previously added a compute-capability-based support check for `mm_fp4`. However, we missed enabling SM121 for backend = `cudnn` and `cutlass`. Additionally, we marked `trtllm` as supported on SM120 when it is not. Current PR fixes it. Example benchmark and pytest command on SM121 after the fix ``` (py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0) warnings.warn( [PERF] cudnn :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec [PERF] cutlass :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec (py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py ====================================================================================================================== test session starts ====================================================================================================================== platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /flashinfer configfile: pytest.ini collected 3240 items ... 
======================================================================================================================= warnings summary ======================================================================================================================== ../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285 /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0) warnings.warn( -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ========================================================================================================= 450 passed, 2790 skipped, 1 warning in 8.24s ========================================================================================================== ``` ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Expanded hardware compatibility by adding support for newer NVIDIA GPU architectures. * FP4 quantized operations now available across multiple backends on supported devices. 
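As a rough sketch of the resulting support matrix (the `MM_FP4_SUPPORT` dict below is illustrative; the source of truth is the `@supported_compute_capability` decorators changed in this patch):

```python
import torch

# Illustrative mm_fp4 backend support after this patch
# (SM numbers: 121 == compute capability 12.1).
MM_FP4_SUPPORT = {
    "cudnn": {100, 103, 110, 120, 121},
    "trtllm": {100, 103},
    "cutlass": {100, 103, 120, 121},
}

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    sm = major * 10 + minor
    print(sm, [name for name, sms in MM_FP4_SUPPORT.items() if sm in sms])
```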
--- benchmarks/routines/flashinfer_benchmark_utils.py | 1 + flashinfer/gemm.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index fa1a527d17..3836e03630 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -241,6 +241,7 @@ def dtype_str_to_torch_dtype(dtype_str): "10.0": ["cudnn", "trtllm", "cutlass"], "10.3": ["cudnn", "trtllm", "cutlass"], "12.0": ["cudnn", "cutlass"], + "12.1": ["cudnn", "cutlass"], }, # MOE "trtllm_fp4_block_scale_moe": { diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 63a2f7e211..8f0d16a015 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -1750,7 +1750,7 @@ def _check_mm_fp4_problem_size( return True -@supported_compute_capability([100, 103, 110, 120]) +@supported_compute_capability([100, 103, 110, 120, 121]) def _cudnn_gemm_fp4_requirement( a: torch.Tensor, b: torch.Tensor, @@ -1808,7 +1808,7 @@ def _cudnn_gemm_fp4_requirement( return True -@supported_compute_capability([100, 103, 120]) +@supported_compute_capability([100, 103]) def _trtllm_gemm_fp4_requirement( a: torch.Tensor, b: torch.Tensor, @@ -1830,7 +1830,7 @@ def _trtllm_gemm_fp4_requirement( return True -@supported_compute_capability([100, 103, 120]) +@supported_compute_capability([100, 103, 120, 121]) def _cutlass_gemm_fp4_requirement( a: torch.Tensor, b: torch.Tensor, From de4c70173c0d7bb1a23821e7ee4f30d132f0cf62 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 30 Oct 2025 23:28:22 -0700 Subject: [PATCH 010/130] fix: ensure SM120/121 SFA/SFB contiguity (#1963) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Fix the regression in vLLM and SGLang with FI 0.4.0 in bmm_fp8 ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes cc: @yzh119 ## Summary by CodeRabbit * **Bug Fixes** * Fixed memory layout handling for tensor operations in GPU computations to ensure proper alignment, improving stability and performance. 
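For context, a minimal sketch (hypothetical shapes) of why the change is needed: `expand()` returns a zero-stride broadcast view, so the per-tensor scales fed to the SM120/121 kernel were not contiguous before this fix:

```python
import torch

# expand() broadcasts without allocating, so the result is a non-contiguous view.
scale = torch.tensor(1.0)
expanded = scale.view(1, 1).expand(4, 8)
print(expanded.is_contiguous())               # False
print(expanded.contiguous().is_contiguous())  # True
```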
--- flashinfer/gemm.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 8f0d16a015..b561a67862 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -259,8 +259,10 @@ def forward( scale_k_count = ( k_dim + scale_gran_k - 1 ) // scale_gran_k # k dimension - scale_a_expanded = scale_a.view(1, 1).expand( - scale_m_count, scale_k_count + scale_a_expanded = ( + scale_a.view(1, 1) + .expand(scale_m_count, scale_k_count) + .contiguous() ) else: scale_a_expanded = scale_a @@ -273,8 +275,10 @@ def forward( scale_k_count = ( k_dim + scale_gran_k - 1 ) // scale_gran_k # k dimension - scale_b_expanded = scale_b.view(1, 1).expand( - scale_n_count, scale_k_count + scale_b_expanded = ( + scale_b.view(1, 1) + .expand(scale_n_count, scale_k_count) + .contiguous() ) else: scale_b_expanded = scale_b From 1181c5d8ee2ddb7d07ac1f7d12cf95e1fab076fa Mon Sep 17 00:00:00 2001 From: Wenxuan Tan Date: Fri, 31 Oct 2025 01:48:56 -0500 Subject: [PATCH 011/130] More realistic bench for POD Attn (#2013) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Use real head sizes, seq lens and add comparison with sequential prefill + decode. Results on H100 (without overlap, which only adds ~150GB/s for persistent): image cc @yzh119 ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit ## Release Notes * **New Features** * Added comprehensive performance benchmarking for batch attention operations with detailed timing measurements. * Introduced sequential dual-kernel benchmark path with extended memory bandwidth reporting. * **Tests** * Updated benchmark test configurations to use deterministic, fixed values for improved reproducibility. * Adjusted benchmark parameters for consistency across test iterations. 
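For reference, the sequential baseline and the bandwidth numbers reported by the benchmark use the arithmetic sketched below (placeholder timings; the benchmark derives `total_bytes` from the Q and paged-KV tensors):

```python
import numpy as np

# Sketch: time prefill and decode separately, sum the medians,
# then convert bytes moved to GB/s.
def bandwidth_gb_s(total_bytes: int, ms: float) -> float:
    return total_bytes / (ms * 1e-3) / (1024**3)

ms_prefill = np.median([3.1, 3.0, 3.2])   # placeholder timings (ms)
ms_decode = np.median([1.1, 1.0, 1.2])
ms_seq_two_kernels = ms_prefill + ms_decode
print(f"{bandwidth_gb_s(8 * 1024**3, ms_seq_two_kernels):.2f} GB/s")
```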
--- benchmarks/bench_mixed_attention.py | 131 +++++++++++++++------------- 1 file changed, 69 insertions(+), 62 deletions(-) diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index 85753a71f9..9773e8f37d 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -72,6 +72,24 @@ def run_bench( measurements = bench_gpu_time(lambda: wrapper_old.run(q, kv_data)) ms_old = np.median(measurements) + wrapper_persistent = flashinfer.BatchAttention(kv_layout="NHD") + wrapper_persistent.plan( + q_indptr.to(device), + kv_indptr.to(device), + torch.arange(num_blocks, dtype=torch.int32, device=device), + seq_lens.to(device), + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + page_block_size, + causal=causal, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + o_persistent, _ = wrapper_persistent.run(q, kv_data) + measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data)) + ms_persistent = np.mean(measurements_persistent) if len(p_kv_lens) == 1: q_d = q[: d_q_indptr[-1]] kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) @@ -123,9 +141,46 @@ def run_bench( ) ) ms_pod = np.median(measurements) + + # Sequential two kernels: single prefill + batch decode (tensor cores) + # Prefill using single_prefill_with_kv_cache + def _run_single_prefill(): + return flashinfer.prefill.single_prefill_with_kv_cache( + q_p, + k_p, + v_p, + causal=causal, + pos_encoding_mode="NONE", + backend="fa2", + ) + + measurements_prefill = bench_gpu_time(lambda: _run_single_prefill()) + ms_prefill = np.median(measurements_prefill) + + # Batch decode using tensor cores + wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True + ) + wrapper_decode.plan( + d_kv_indptr.to(device), + kv_indices_d.to(device), + last_page_len_d, + num_qo_heads, + num_kv_heads, + head_dim, + page_block_size, + data_type=torch.bfloat16, + q_data_type=torch.bfloat16, + ) + measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d)) + ms_decode = np.median(measurements_decode) + ms_seq_two_kernels = ms_prefill + ms_decode + print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") if len(p_kv_lens) == 1: print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") + print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms") + print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms") total_bytes = ( q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() ) @@ -137,6 +192,14 @@ def run_bench( if len(p_kv_lens) == 1: bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3) print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s") + bandwidth_seq_gb_s = total_bytes / (ms_seq_two_kernels * 1e-3) / (1024**3) + print( + f"Memory bandwidth (Sequential two kernels): {bandwidth_seq_gb_s:.2f} GB/s" + ) + bandwidth_persistent_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3) + print( + f"Memory bandwidth (Persistent BatchAttention): {bandwidth_persistent_gb_s:.2f} GB/s" + ) if __name__ == "__main__": @@ -144,70 +207,14 @@ def run_bench( torch.random.manual_seed(42) # Irregular sequence lengths for prefill and decode - d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256] - d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256] - p_q_configs = [[17] * 1, [10000], [17] * 1, []] - p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []] - - # construct random length testcases - for 
_ in range(1): - bsz = 256 - stride = 16 - sparsity = 0.05 - - full_kv_len = np.random.randint(1000, 8192, size=bsz) - p_q_lens = [] - p_kv_lens = [] - d_q_lens = [] - d_kv_lens = [] - for i in range(bsz): - if i % stride == 0: - kv_len = full_kv_len[i] - qo_len = stride + 1 - p_q_lens.append(qo_len) - p_kv_lens.append(kv_len) - else: - kv_len = int(full_kv_len[i] * sparsity) - qo_len = 1 - d_q_lens.append(qo_len) - d_kv_lens.append(kv_len) - - p_q_configs.append(p_q_lens) - p_kv_configs.append(p_kv_lens) - d_q_len_configs.append(d_q_lens) - d_kv_len_configs.append(d_kv_lens) - - for _ in range(1): - bsz = 128 - stride = 16 - sparsity = 0.05 - - full_kv_len = np.random.randint(2000, 16000, size=bsz) - p_q_lens = [] - p_kv_lens = [] - d_q_lens = [] - d_kv_lens = [] - - for i in range(bsz): - if i % stride == 0: - kv_len = full_kv_len[i] - qo_len = stride + 1 - p_q_lens.append(qo_len) - p_kv_lens.append(kv_len) - else: - kv_len = int(full_kv_len[i] * sparsity) - qo_len = 1 - d_q_lens.append(qo_len) - d_kv_lens.append(kv_len) - - p_q_configs.append(p_q_lens) - p_kv_configs.append(p_kv_lens) - d_q_len_configs.append(d_q_lens) - d_kv_len_configs.append(d_kv_lens) + d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128] + d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128] + p_q_configs = [[2048], [4096], [4096], [6000]] + p_kv_configs = [[2048], [4096], [4096], [7000]] page_block_size = 1 - num_kv_heads = 4 - num_qo_heads = 28 + num_kv_heads = 8 + num_qo_heads = 32 head_dim = 128 for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate( From f9cd0345a162f4b19d62a0918ba027a3c59917a7 Mon Sep 17 00:00:00 2001 From: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com> Date: Fri, 31 Oct 2025 18:43:12 +0200 Subject: [PATCH 012/130] Feature: Support non-gated activation in cutlass fused MoE nvfp4 (#2011) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This PR removes an assertion in the cutlass fused moe bindings to enable non-gated activations in nvfp4. It also adds a test for this path with relu2 activation. ## ๐Ÿ” Related Issues N/A ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [v] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [v] I have installed the hooks with `pre-commit install`. - [v] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [v] Tests have been added or updated as needed. - [v] All tests are passing (`unittest`, etc.). ## Reviewer Notes N/A ## Summary by CodeRabbit * **New Features** * Enhanced quantized Mixture of Experts models to support configurable activation types (Swiglu and ReLU2) in the NVFP4 quantization path. * Improved parameter handling to correctly adapt weight shapes and quantization settings based on the selected activation type. 
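For reviewers, a minimal sketch of the per-expert FC1 reference math the new test exercises (shapes only; the real check lives in `tests/moe/test_trtllm_cutlass_fused_moe.py`): gated Swiglu splits the FC1 weight into up/gate halves, while the non-gated Relu2 path keeps a single `[n, k]` weight:

```python
import torch
import torch.nn.functional as F

def fc1_swiglu(x, w1):  # w1: [2n, k], rows [n:] act as the gate
    m = w1.shape[0]
    w_up, w_gate = w1[: m // 2], w1[m // 2 :]
    return F.silu(x @ w_gate.t()) * (x @ w_up.t())

def fc1_relu2(x, w1):   # w1: [n, k], no gate half
    return F.relu(x @ w1.t()) ** 2

x = torch.randn(4, 16)
print(fc1_swiglu(x, torch.randn(8, 16)).shape,  # gated: output dim is n = 4
      fc1_relu2(x, torch.randn(8, 16)).shape)   # non-gated: output dim is n = 8
```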
--------- Signed-off-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com> --- ...shinfer_cutlass_fused_moe_sm100_binding.cu | 55 +++++++++++------- tests/moe/test_trtllm_cutlass_fused_moe.py | 57 +++++++++++++------ 2 files changed, 76 insertions(+), 36 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu index 5bef6f8719..267e4591cc 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu @@ -361,8 +361,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { num_rows, hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), base_activation_type, parallelism_config, min_latency_mode); - auto const quant_params = - getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales); + auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, + quant_scales, base_activation_type); kernels::MoeMinLatencyParams min_latency_params{}; // TODO: support lora in the future @@ -542,8 +542,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { num_rows, hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), base_activation_type, parallelism_config, min_latency_mode); - auto const quant_params = - getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales); + auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, + quant_scales, base_activation_type); // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; @@ -809,9 +809,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { return info; } - kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size, - int64_t inter_size, - Optional> quant_scales) const { + kernels::QuantParams getQuantParams( + int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size, + Optional> quant_scales, + ActivationType base_activation_type = ActivationType::Swiglu) const { if (isFp8Quant()) { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization"; TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) @@ -1013,18 +1014,34 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // Check shapes TVM_FFI_ICHECK(fc1_act_global.ndim() == 0 || fc1_act_global.size(0) == num_experts_on_rank) << "fc1 act global must be scalar or (num_experts_on_rank,)"; - TVM_FFI_ICHECK( - fc1_weight_block.size(0) == num_experts_on_rank && - fc1_weight_block.size(1) == - TmaWarpSpecializedGroupedGemmInput::alignToSfDim( - inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) * - 2 && - fc1_weight_block.size(2) * FP8_PER_INT32 * - TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize == - TmaWarpSpecializedGroupedGemmInput::alignToSfDim( - hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4)) - << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 " - "// block_scale_vector_size)"; + if (isGatedActivation(base_activation_type)) { + TVM_FFI_ICHECK( + fc1_weight_block.size(0) == num_experts_on_rank && + fc1_weight_block.size(1) == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) * + 2 && + fc1_weight_block.size(2) * FP8_PER_INT32 * + 
TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4)) + << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // " + "4 " + "// block_scale_vector_size)"; + } else { + TVM_FFI_ICHECK( + fc1_weight_block.size(0) == num_experts_on_rank && + fc1_weight_block.size(1) == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) && + fc1_weight_block.size(2) * FP8_PER_INT32 * + TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4)) + << "fc1 weight block size must be (num_experts_on_rank, inter_size, hidden_size // 4 " + "// block_scale_vector_size)"; + } + TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank) << "fc1 global size must be (num_experts_on_rank,)"; TVM_FFI_ICHECK(fc2_act_global.ndim() == 0 || fc2_act_global.size(0) == num_experts_on_rank) diff --git a/tests/moe/test_trtllm_cutlass_fused_moe.py b/tests/moe/test_trtllm_cutlass_fused_moe.py index b9b79a4028..bae12ab070 100644 --- a/tests/moe/test_trtllm_cutlass_fused_moe.py +++ b/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -17,6 +17,7 @@ from contextlib import nullcontext import pytest +from flashinfer.fused_moe.core import ActivationType import torch from torch.nn import functional as F @@ -137,7 +138,7 @@ def compute_routing( return routing_weights, selected_experts -def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids): +def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids, activation_type): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) @@ -147,13 +148,26 @@ def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids): topk_ids = topk_ids.view(-1) # w1 needs to be swapped in terms of gate and up_proj + if activation_type == ActivationType.Swiglu: + + def act(weight, mask): + m = weight.shape[0] + assert m % 2 == 0 + w1_expert, w3_expert = weight[m // 2 :, :], weight[: m // 2, :] + return F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t()) + + elif activation_type == ActivationType.Relu2: + + def act(weight, mask): + return F.relu(a[mask] @ weight.t()) ** 2 + + else: + raise ValueError(f"Unsupported activation type {activation_type}") + for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - m = w1[i].shape[0] - assert m % 2 == 0 - w1_expert, w3_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :] - inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t()) + inter = act(w1[i], mask) inter_gs = torch.tensor(1.0).cuda() inter_q, inter_blockscale = fp4_quantize(inter, inter_gs) inter = dequantize_nvfp4_to_dtype( @@ -363,6 +377,11 @@ def test_moe_fp8( [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)], ) @pytest.mark.parametrize("quantized_input", [False, True]) +@pytest.mark.parametrize( + "activation_type", + [ActivationType.Swiglu, ActivationType.Relu2], + ids=["swiglu", "relu2"], +) @pytest.mark.skipif( torch.cuda.get_device_capability()[0] not in [10, 11, 12], reason="NVFP4 is only supported on SM100, SM110 and SM120", @@ -376,6 +395,7 @@ def test_moe_nvfp4( otype, wtype, quantized_input, + activation_type, ): # Skip invalid configurations if top_k > num_experts: @@ -391,10 +411,10 @@ def test_moe_nvfp4( 
n = intermediate_size k = hidden_size - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 - w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous() + w1_n = 2 * n if activation_type == ActivationType.Swiglu else n + w1 = torch.randn((e, w1_n, k), device="cuda", dtype=otype) / 10 - sf_w1_2n = round_up(2 * n, 128) + sf_w1_2n = round_up(w1_n, 128) sf_w1_k = round_up(k // quant_blocksize, 4) w1_blockscale = torch.empty( (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn @@ -409,8 +429,8 @@ def test_moe_nvfp4( w2_blockscale = torch.empty( (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn ) - w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) - w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) + w1_q = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8) + w1_q_cutlass = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8) w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32) w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32) @@ -424,7 +444,7 @@ def test_moe_nvfp4( w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert]) w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize( - w1_cutlass[expert], w1_gs[expert] + w1[expert], w1_gs[expert] ) w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert]) @@ -469,6 +489,7 @@ def test_moe_nvfp4( quant_scales=quant_scales, input_sf=input_sf, output=flash_output, + activation_type=activation_type, ) # Ref check @@ -483,7 +504,7 @@ def test_moe_nvfp4( block_size=quant_blocksize, ) - w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=otype) + w1_d = torch.empty((e, w1_n, k), device="cuda", dtype=otype) w2_d = torch.empty((e, k, n), device="cuda", dtype=otype) for idx in range(0, e): @@ -504,12 +525,14 @@ def test_moe_nvfp4( block_size=quant_blocksize, ) - w1_q_cutlass = torch.cat((w1_q[:, n:, :], w1_q[:, :n, :]), dim=1).contiguous() - w1_blockscale_cutlass = torch.cat( - (w1_blockscale[:, n:, :], w1_blockscale[:, :n, :]), dim=1 - ).contiguous() ref_output = torch_moe_nvfp4( - a_in_dtype, w1_d, w2_d, top_k, routing_weights, selected_experts + a_in_dtype, + w1_d, + w2_d, + top_k, + routing_weights, + selected_experts, + activation_type, ) torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1) From 5854494a0187795800862c73bad7c727231ab60f Mon Sep 17 00:00:00 2001 From: qsang-nv <200703406+qsang-nv@users.noreply.github.com> Date: Sun, 2 Nov 2025 09:06:41 +0800 Subject: [PATCH 013/130] feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend (#2001) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Expose xqa backend to trtllm attention interface, and improve layout coverage of trtllm-gen and xqa backends. Now both trtllm-gen/xqa supports NHD/HND kv-cache layout. * support NHD layout for trtllm-gen * refactor xqa (https://github.com/flashinfer-ai/flashinfer/commit/869c0c1c6bc199f82f30c23ab78a1b4aa9a1bd3a) * allow user passed stride_page/head/token * support both HND and NHD * remove macros such as PAGED_KV_CACHE_LAYOUT and USE_PAGED_KV_CACHE * adding unittests for both trtllm-gen/xqa on NHD/HND * adding unified API for trtllm-gen/xqa, and unified unittest ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! 
Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Added xqa-based batch decode API and public kv_layout option (NHD/HND); added enable_pdl toggle to inference wrappers. * **Improvements** * Automatic backend selection for decoding, consistent KV-layout normalization across paths, and unified stride-aware paged-KV handling with layout-aware shapes, scales, and workspace handling. * **Tests** * Expanded tests to cover both KV layouts, enable_pdl, new batch-decode workflows, backend/layout permutations, and fp8/mixed-dtype scenarios. --------- Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> Co-authored-by: yzh119 Co-authored-by: Zihao Ye --- csrc/flashinfer_xqa_binding.cu | 25 +- csrc/trtllm_fmha_kernel_launcher.cu | 11 +- csrc/xqa/defines.h | 27 +- csrc/xqa/mha.cu | 218 ++------- csrc/xqa/mha.h | 71 +-- csrc/xqa/mhaUtils.cuh | 50 +- csrc/xqa/mha_sm90.cu | 254 ++--------- csrc/xqa/mla_sm120.cu | 179 ++------ csrc/xqa/tensorMap.cpp | 21 +- csrc/xqa/tensorMap.h | 3 +- csrc/xqa/utils.cuh | 2 +- csrc/xqa/xqa_wrapper.cu | 49 +- flashinfer/decode.py | 399 ++++++++++++---- flashinfer/jit/xqa.py | 4 - flashinfer/prefill.py | 27 +- flashinfer/xqa.py | 37 +- tests/attention/test_trtllm_gen_attention.py | 143 ++++-- tests/attention/test_xqa.py | 135 ++++-- tests/attention/test_xqa_batch_decode.py | 457 +++++++++++++++++++ 19 files changed, 1236 insertions(+), 876 deletions(-) create mode 100644 tests/attention/test_xqa_batch_decode.py diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu index 003a23a5f6..40b4168a9b 100644 --- a/csrc/flashinfer_xqa_binding.cu +++ b/csrc/flashinfer_xqa_binding.cu @@ -18,14 +18,10 @@ #if MLA_WRAPPER void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q, -#if PAGED_KV_CACHE_LAYOUT == 1 - TensorView kCacheVLLM, TensorView vCacheVLLM, -#else - TensorView pool, -#endif - TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, - int64_t batchSize, TensorView kvCacheScale, TensorView semaphores, - TensorView scratch); + TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, + int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, + TensorView kvCacheScale, TensorView semaphores, TensorView scratch, + bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla); @@ -36,18 +32,13 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK #if LOW_PREC_OUTPUT TensorView rcpOutScale, #endif - TensorView q, tvm::ffi::Optional attentionSinks, -#if PAGED_KV_CACHE_LAYOUT == 1 - TensorView kCacheVLLM, TensorView vCacheVLLM, -#else - TensorView pool, -#endif - TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, - int64_t batchSize, TensorView kvCacheScale, + TensorView q, tvm::ffi::Optional attentionSinks, TensorView kCacheVLLM, + TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, + 
TensorView seqLen, int64_t batchSize, TensorView kvCacheScale, #if SPEC_DEC int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif - TensorView semaphores, TensorView scratch); + TensorView semaphores, TensorView scratch, bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper); diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index c40e773e64..89d958ce7f 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -228,15 +228,17 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal TVM_FFI_ICHECK((head_dim_v == 576 && head_dim_o == 512) || head_dim_v == head_dim_o) << "head_dim_v and head_dim_o must be the same for non-MLA attention, got " << std::to_string(head_dim_v) << " and " << std::to_string(head_dim_o); - int page_size = key_cache.size(-2); - int num_kv_heads = key_cache.size(-3); int max_num_blocks_per_seq = block_tables.size(-1); bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr(); int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2; + // Assume NHD layout: [..., H, N, D] + int page_size = key_cache.size(-2); + int num_kv_heads = key_cache.size(-3); int kv_stride_keys_values = key_cache.stride(-2); // key/values int kv_stride_heads = key_cache.stride(-3); // head - int kv_stride_batch = key_cache.stride(0); // batch + + int kv_stride_batch = key_cache.stride(0); // batch const auto stream = get_stream(query.device()); void* output_sf_ptr = @@ -291,9 +293,10 @@ void trtllm_paged_attention_context(TensorView out, Optional out_sca int max_num_blocks_per_seq = block_tables.size(-1); bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr(); int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2; + + // Assume NHD layout: [..., H, N, D] int page_size = key_cache.size(-2); int num_kv_heads = key_cache.size(-3); - int kv_stride_keys_values = key_cache.stride(-2); // key/values int kv_stride_heads = key_cache.stride(-3); // head int kv_stride_batch = key_cache.stride(0); // batch diff --git a/csrc/xqa/defines.h b/csrc/xqa/defines.h index 6f0acc4f85..ca8589d808 100644 --- a/csrc/xqa/defines.h +++ b/csrc/xqa/defines.h @@ -92,21 +92,6 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena #define TOKENS_PER_PAGE 32 #endif -// don't modify -#ifndef USE_PAGED_KV_CACHE -#define USE_PAGED_KV_CACHE (TOKENS_PER_PAGE > 0) -#endif - -// Paged KV Cache Format -// 0 - XQA Original -// 1 - separate K and V cache pools, each with layout (batch, seq_len, head, head_elem) for -// VLLM/SGLang -#ifdef USE_PAGED_KV_CACHE -#ifndef PAGED_KV_CACHE_LAYOUT -#define PAGED_KV_CACHE_LAYOUT 0 -#endif -#endif - // don't modify #define USE_BEAM_SEARCH (BEAM_WIDTH > 1) @@ -129,7 +114,16 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena // 1 - naive PDL // 2 - aggressive PDL (implemented only in mha_sm90.cu for now) #ifndef ENABLE_PDL +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +#if __CUDA_ARCH__ == 900 #define ENABLE_PDL 2 +#else +#define ENABLE_PDL 1 +#endif +#else +/* default for host or older architectures */ +#define ENABLE_PDL 0 +#endif #endif #ifndef USE_INPUT_KV @@ -161,8 +155,7 @@ static_assert(CACHE_ELEM_ENUM != 0); #endif // true should be better if warpTile.x * cacheElemSize < 128. otherwise use false. 
-#define GRP_LOAD_V \ - (CACHE_ELEM_ENUM != 0) || (HEAD_ELEMS == 256 && USE_PAGED_KV_CACHE && BEAM_WIDTH > 1) +#define GRP_LOAD_V (CACHE_ELEM_ENUM != 0) || (HEAD_ELEMS == 256 && BEAM_WIDTH > 1) // use custom barrier for NVRTC to avoid pulling in many headers #ifndef USE_CUSTOM_BARRIER diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index cf6778fb20..1693d025f0 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -89,7 +89,7 @@ constexpr uint32_t cvtExpansion = exactDiv(inputElemSize, cacheElemSize); constexpr uint32_t preferedKHeadPartBytes = 64; __constant__ constexpr uint32_t cacheVTileSeqLen = 32; #else -#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 +#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210 constexpr uint32_t preferedKHeadPartBytes = 64; __constant__ constexpr uint32_t cacheVTileSeqLen = 32; #elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 || \ @@ -293,14 +293,12 @@ constexpr uint32_t nbCacheVTilesPerXTile = exactDiv(warpTile.x, cacheVTileSeqLen constexpr uint32_t nbWarpGrpsPerXTile = mha::min(nbCacheVTilesPerXTile, gemm1NbWarpGrps); -#if USE_PAGED_KV_CACHE constexpr uint32_t nbPagesPerWarpTile = (warpTile.x <= tokensPerPage ? 1U : exactDiv(warpTile.x, tokensPerPage)); using KCachePageIndices = Vec; constexpr uint32_t nbPagesPerVTile = (cacheVTileSeqLen <= tokensPerPage ? 1 : exactDiv(cacheVTileSeqLen, tokensPerPage)); using VCachePageIndices = Vec; -#endif static_assert(ctaShapeInWarps.y == 1); @@ -336,10 +334,8 @@ struct alignas(128) SharedMem { #if BEAM_WIDTH > 1 Vec gemm0CacheIndir[ctaShapeInWarps.x]; Vec gemm1CacheIndir[grpLoadV ? gemm1NbWarpGrps : ctaShapeInWarps.x]; -#if USE_PAGED_KV_CACHE Vec kCachePages[ctaShapeInWarps.x]; Vec vCachePages[grpLoadV ? gemm1NbWarpGrps : ctaShapeInWarps.x]; -#endif #endif using Barrier = CtaBarrier; @@ -1307,6 +1303,7 @@ CUBIN_EXPORT __global__ uint32_t const batchSize, float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V // cache. Used only for int8/fp8 KV cache. + uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head, uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { assert(allowMultiBlockMode || gridDim.x == 1); bool const isMultiBlock = allowMultiBlockMode && (gridDim.x != 1); @@ -1483,10 +1480,8 @@ CUBIN_EXPORT __global__ #endif uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / ctaTile.x; uint32_t const tile0NbSkipTokens = nbTotalSkipTokens % ctaTile.x; -#if USE_PAGED_KV_CACHE uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage); constexpr uint32_t nbPagesPerCtaTile = exactDiv(ctaTile.x, tokensPerPage); -#endif uint32_t const nbSeqIters = useKVCache ? 
divUp(cacheSeqLen, ctaTile.x) : 0; #if SPEC_DEC @@ -1523,7 +1518,6 @@ CUBIN_EXPORT __global__ }; loadCacheIndir(seqIterInit, 0U); #endif -#if USE_PAGED_KV_CACHE #if BEAM_WIDTH == 1 KCachePageIndices pageIdx = KCachePageIndices::filled(kBAD_PAGE_INDEX); #endif @@ -1539,11 +1533,6 @@ CUBIN_EXPORT __global__ }; uint32_t idxPageBeg = nbPagesPerCtaTile * seqIterInit + warpIdx.x * warpTile.x / tokensPerPage; loadPages(idxPageBeg); -#else - constexpr uint32_t idxBeamBase = 0U; - uint32_t const cacheKSeqBaseOffset = - cacheList.capacity * (idxHeadGrp + nbKHeads * 2 * (idxBeamBase + beamWidth * idxReq)); -#endif auto loadKTilePart = [&](uint32_t seqIter, uint32_t idxBeam, uint32_t idxPart) mutable { assert(idxBeam < beamWidth); assert(seqIter % nbSubSeqPerSeq == seqIterInit % nbSubSeqPerSeq); @@ -1551,46 +1540,22 @@ CUBIN_EXPORT __global__ auto& dst = getSMemKTile(idxNextSMemKBuf); uint32_t const dstHeadOffset = 0; uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * warpIdx.x; -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 - uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp; + uint32_t const tokenOffset = seqOffset % tokensPerPage; -#else - uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage; -#endif #if BEAM_WIDTH == 1 -#if PAGED_KV_CACHE_LAYOUT == 1 HeadPtr const src{ - cacheList.kCacheVLLM, pageIdx, nbKHeads, idxHeadBeg}; -#else - HeadPtr const src{ - cacheList.pool, pageIdx, nbKHeads, idxHeadBeg}; -#endif + cacheList.kCacheVLLM, pageIdx, tokenOffset, idxHeadGrp, + kv_stride_page, kv_stride_token, kv_stride_head}; #else IndexedHeadPtr const src{ /*indices=*/smem.gemm0CacheIndir[warpIdx.x].data, -#if PAGED_KV_CACHE_LAYOUT == 1 /*pool=*/cacheList.kCacheVLLM, -#else - /*pool=*/cacheList.pool, -#endif /*pageIndices=*/smem.kCachePages[warpIdx.x].data, - /*nbKHeads=*/nbKHeads, - /*offset=*/idxHeadBeg}; -#endif -#else - uint32_t const idxHeadBeg = cacheKSeqBaseOffset + seqOffset; -#if BEAM_WIDTH == 1 - TinyPtr const src{cacheList.data, idxHeadBeg}; -#else - IndexedHeadPtr const src{ - /*indices=*/smem.gemm0CacheIndir[warpIdx.x].data, - /*pointer=*/cacheList.data, - /*offset=*/idxHeadBeg, - /*beamStride=*/cacheList.capacity * nbKHeads * 2}; - // trap(); - // assert("not implemented"); -#endif + /*tokenOffset=*/tokenOffset, + /*headIdx=*/idxHeadGrp, + /*stride_page=*/kv_stride_page, + /*stride_token=*/kv_stride_token, + /*stride_head=*/kv_stride_head}; #endif // if (threadIdx.x == dbgPrintTid) { // printf("K: seqIter=%u, idxBeam=%u, idxPart=%u: pointers={%p, %p}, indices={", seqIter, @@ -1618,13 +1583,11 @@ CUBIN_EXPORT __global__ __syncwarp(); #endif if (idxPart + 1 == nbPartsPerCacheKHead) { -#if USE_PAGED_KV_CACHE bool const isForNextSeqIter = isConvergedTile(seqIter) || idxBeam == beamWidth - 1; if (isForNextSeqIter) { idxPageBeg += nbPagesPerCtaTile * nbSubSeqPerSeq; loadPages(idxPageBeg); } -#endif #if BEAM_WIDTH > 1 uint32_t idxBeamNext, seqIterDelta; mha::tie(idxBeamNext, seqIterDelta) = @@ -1831,7 +1794,6 @@ CUBIN_EXPORT __global__ auto const getSmemVBar = [&](uint32_t idx) -> SharedMem::Barrier* { return smem.vBarrier(warpGrpIdx, idx); }; -#if USE_PAGED_KV_CACHE #if BEAM_WIDTH == 1 VCachePageIndices pageIdx = VCachePageIndices::filled(kBAD_PAGE_INDEX); #endif @@ -1849,12 +1811,6 @@ CUBIN_EXPORT __global__ uint32_t idxPageBeg = nbPagesPerCtaTile * seqIterInit + cacheVTileSeqLen * warpGrpIdx / tokensPerPage; loadPages(idxPageBeg); -#else - uint32_t const idxBeamBase = 0; - uint32_t const cacheVSeqBaseOffset = - 
cacheList.capacity * - (nbKHeads + idxHeadGrp + nbKHeads * 2 * (idxBeamBase + beamWidth * idxReq)); -#endif auto nextStep = [&](uint32_t seqIter, uint32_t xIter, uint32_t vIter, uint32_t idxBeam) { uint32_t vIterNext, isNextBeam; mha::tie(vIterNext, isNextBeam) = carryLE(vIter + 1, 0); @@ -1881,44 +1837,22 @@ CUBIN_EXPORT __global__ uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * nbXTilesPerXIter * xIter + cacheVTileSeqStride * vIter + cacheVTileSeqLen * warpGrpIdx; -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 - uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp; + uint32_t const tokenOffset = seqOffset % tokensPerPage; -#else - uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage; -#endif #if BEAM_WIDTH == 1 -#if PAGED_KV_CACHE_LAYOUT == 1 - HeadPtr const src{ - cacheList.vCacheVLLM, pageIdx, nbKHeads, idxHeadBeg}; -#else HeadPtr const src{ - cacheList.pool, pageIdx, nbKHeads, idxHeadBeg}; -#endif + cacheList.vCacheVLLM, pageIdx, tokenOffset, idxHeadGrp, + kv_stride_page, kv_stride_token, kv_stride_head}; #else IndexedHeadPtr const src{ /*indices=*/smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x].data, -#if PAGED_KV_CACHE_LAYOUT == 1 /*pool=*/cacheList.vCacheVLLM, -#else - /*pool=*/cacheList.pool, -#endif /*pageIndices=*/smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x].data, - /*nbKHeads=*/nbKHeads, - /*offset=*/idxHeadBeg}; -#endif -#else - uint32_t const idxHeadBeg = cacheVSeqBaseOffset + seqOffset; -#if BEAM_WIDTH == 1 - TinyPtr const src{cacheList.data, idxHeadBeg}; -#else - IndexedHeadPtr const src{ - /*indices=*/smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x].data, - /*pointer=*/cacheList.data, - /*offset=*/idxHeadBeg, - /*beamStride=*/cacheList.capacity * nbKHeads * 2}; -#endif + /*tokenOffset=*/tokenOffset, + /*headIdx=*/idxHeadGrp, + /*stride_page=*/kv_stride_page, + /*stride_token=*/kv_stride_token, + /*stride_head=*/kv_stride_head}; #endif // if (threadIdx.x == dbgPrintTid) { // printf("V: seqIter=%u, xIter=%u, idxBeam=%u, vIter=%u: pointers={%p, %p}, indices={", @@ -1963,7 +1897,6 @@ CUBIN_EXPORT __global__ unused(arrive(pWarpGrpBar)); wait_parity(pWarpGrpBar, getAndFlip(warpGrpBarParityNext)); #endif -#if USE_PAGED_KV_CACHE constexpr uint32_t xIterSeqStride = cacheVTileSeqStride * nbVItersPerXIter; if constexpr (xIterSeqStride <= tokensPerPage) { uint32_t const nbXItersPerPage = exactDiv(tokensPerPage, xIterSeqStride); @@ -1990,7 +1923,6 @@ CUBIN_EXPORT __global__ loadPages(idxPageBeg); } } -#endif #if BEAM_WIDTH > 1 uint32_t seqIterNext, xIterNext, vIterNext, idxBeamNext; mha::tie(seqIterNext, xIterNext, vIterNext, idxBeamNext) = @@ -2480,6 +2412,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( uint32_t const batchSize, float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. // Used only for int8/fp8 KV cache. 
+ uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head, uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { #if SPEC_DEC kernel_mha_impl(qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, @@ -2501,7 +2434,8 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( #if BEAM_WIDTH > 1 beamSearchParams, #endif - batchSize, kvCacheScale, semaphores, scratch); + batchSize, kvCacheScale, kv_stride_page, kv_stride_token, kv_stride_head, + semaphores, scratch); } #else static constexpr auto kernel_mha = kernel_mha_impl; @@ -2526,18 +2460,10 @@ void launchMHA( InputHead const* q, #endif float const* attentionSinks, // [headGrpSize] -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, // global pool of pages -#endif KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]. -#else - GMemKVCacheHead* kvCacheData, -#endif uint32_t maxSeqLen, uint32_t const* seqLen, #if BEAM_WIDTH > 1 BeamSearchParams const& beamSearchParams, @@ -2548,7 +2474,8 @@ void launchMHA( #if SPEC_DEC SpecDecParams const& specDecParams, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream) { + uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, + uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { #if SPEC_DEC auto const qSeqLen = specDecParams.qSeqLen; auto const qCuSeqLens = specDecParams.qCuSeqLens; @@ -2590,15 +2517,15 @@ void launchMHA( dim3 const dimGrid{nbSubSeqPerSeq, nbKHeads, batchSize}; #endif dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z}; - auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); -#if USE_PAGED_KV_CACHE + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); -#if PAGED_KV_CACHE_LAYOUT == 1 KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; -#else - KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; -#endif + // Convert stride from elements to Heads + uint32_t const stride_page_in_heads = static_cast(kv_stride_page / validElemsPerHead); + uint32_t const stride_token_in_heads = static_cast(kv_stride_token / validElemsPerHead); + uint32_t const stride_head_in_heads = static_cast(kv_stride_head / validElemsPerHead); + cudaLaunchKernelEx(&launchCfg, kernel_mha, #if SPEC_DEC qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, @@ -2620,36 +2547,8 @@ void launchMHA( #if BEAM_WIDTH > 1 beamSearchParams, #endif - batchSize, kvCacheScale, semaphores, scratch); -#else - KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; -#ifndef NDEBUG - kernel_mha<<>>( -#else - cudaLaunchKernelEx(&launchCfg, &kernel_mha, -#endif -#if SPEC_DEC - qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, -#else - nbKHeads, -#endif -#if SLIDING_WINDOW - slidingWinSize, -#endif - qScale, output, -#if LOW_PREC_OUTPUT - rcpOutScale, -#endif - q, -#if SPEC_DEC - mask, -#endif - attentionSinks, cacheList, -#if BEAM_WIDTH > 1 - beamSearchParams, -#endif - batchSize, kvCacheScale, semaphores, scratch); -#endif + batchSize, kvCacheScale, stride_page_in_heads, stride_token_in_heads, + stride_head_in_heads, semaphores, scratch); checkCuda(cudaPeekAtLastError()); #endif // USE_INPUT_KV } @@ -2669,19 +2568,16 @@ 
void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 #if LOW_PREC_OUTPUT float const* rcpOutScale, #endif - InputHead const* q, float const* attentionSinks, -#if PAGED_KV_CACHE_LAYOUT == 1 - GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, -#endif - KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, - uint32_t const* seqLen, uint32_t batchSize, + InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, + GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream) { + uint32_t* semaphores, void* scratch, bool enable_pdl, + uint64_t kv_stride_page, uint64_t kv_stride_token, uint64_t kv_stride_head, + cudaStream_t stream) { uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { if (!allowMultiBlockMode) { return 1; @@ -2696,15 +2592,15 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 dim3 const dimGrid{nbSubSeqPerSeq, nbKHeads, batchSize}; #endif dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z}; - auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); -#if USE_PAGED_KV_CACHE + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); -#if PAGED_KV_CACHE_LAYOUT == 1 KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; -#else - KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; -#endif + // Convert stride from elements to Heads + uint32_t const stride_page_in_heads = static_cast(kv_stride_page / validElemsPerHead); + uint32_t const stride_token_in_heads = static_cast(kv_stride_token / validElemsPerHead); + uint32_t const stride_head_in_heads = static_cast(kv_stride_head / validElemsPerHead); + cudaLaunchKernelEx(&launchCfg, kernel_mha, #if SPEC_DEC qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, @@ -2722,32 +2618,8 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 #if SPEC_DEC mask, #endif - attentionSinks, cacheList, batchSize, kvCacheScale, semaphores, scratch); -#else - KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; -#ifndef NDEBUG - kernel_mha<<>>( -#else - cudaLaunchKernelEx(&launchCfg, &kernel_mha, -#endif -#if SPEC_DEC - qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, -#else - nbKHeads, -#endif -#if SLIDING_WINDOW - slidingWinSize, -#endif - qScale, output, -#if LOW_PREC_OUTPUT - rcpOutScale, -#endif - q, -#if SPEC_DEC - mask, -#endif - attentionSinks, cacheList, batchSize, kvCacheScale, semaphores, scratch); -#endif + attentionSinks, cacheList, batchSize, kvCacheScale, stride_page_in_heads, + stride_token_in_heads, stride_head_in_heads, semaphores, scratch); checkCuda(cudaPeekAtLastError()); } #endif diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h index d50c081b6a..43aed55f95 100644 --- a/csrc/xqa/mha.h +++ b/csrc/xqa/mha.h @@ -46,7 +46,7 @@ constexpr bool useKVCache = USE_KV_CACHE; using SeqLenDataType = uint32_t; -constexpr bool usePagedKVCache = USE_PAGED_KV_CACHE; +constexpr bool usePagedKVCache = true; constexpr uint32_t tokensPerPage = TOKENS_PER_PAGE; using IOHead = Vec; @@ -106,18 +106,10 @@ void launchMHA( 
InputHead const* q, #endif float const* attentionSinks, // [headGrpSize] -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, // global pool of pages -#endif KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] -#else - GMemKVCacheHead* kvCacheData, -#endif uint32_t maxSeqLen, uint32_t const* seqLen, #if BEAM_WIDTH > 1 BeamSearchParams const& beamSearchParams, @@ -128,26 +120,24 @@ void launchMHA( #if SPEC_DEC SpecDecParams const& specDecParams, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream); + uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, + uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream); void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, float qScale, OutputHead* output, #if LOW_PREC_OUTPUT float const* rcpOutScale, #endif - InputHead const* q, float const* attentionSinks, -#if PAGED_KV_CACHE_LAYOUT == 1 - GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, -#endif - KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, - uint32_t const* seqLen, uint32_t batchSize, + InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, + GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream); + uint32_t* semaphores, void* scratch, bool enable_pdl, + uint64_t kv_stride_page, uint64_t kv_stride_token, uint64_t kv_stride_head, + cudaStream_t stream); void launchHopperF8MHA( cudaDeviceProp const& prop, uint32_t nbKHeads, @@ -167,18 +157,10 @@ void launchHopperF8MHA( InputHead const* q, #endif float const* attentionSinks, // [headGrpSize] -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, // global pool of pages -#endif KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]. 
-#else - GMemKVCacheHead* kvCacheData, -#endif uint32_t maxSeqLen, uint32_t const* seqLen, #if BEAM_WIDTH > 1 BeamSearchParams const& beamSearchParams, @@ -189,7 +171,7 @@ void launchHopperF8MHA( #if SPEC_DEC SpecDecParams const& specDecParams, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream); + uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream); void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, float qScale, OutputHead* output, @@ -197,50 +179,36 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads float const* rcpOutScale, #endif InputHead const* q, float const* attentionSinks, -#if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, -#endif KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream); + uint32_t* semaphores, void* scratch, bool enable_pdl, + uint64_t kv_stride_page, uint64_t kv_stride_token, + uint64_t kv_stride_head, cudaStream_t stream); void launchMLA( cudaDeviceProp const& prop, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed - float qScale, OutputHead* output, InputHead const* q, -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 - GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, // global pool of pages -#endif + float qScale, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM, + GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or // [batchSize][maxNbPagesPerSeq] (Layout 1) -#else - GMemKVCacheHead* kvCacheData, -#endif uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. // Used only for int8/fp8 KV cache. - uint32_t* semaphores, void* scratch, cudaStream_t stream); + uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream); void launchMLAFlashInfer( uint32_t multiProcessorCount, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed - float qScale, OutputHead* output, InputHead const* q, -#if PAGED_KV_CACHE_LAYOUT == 1 - GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, // global pool of pages -#endif + float qScale, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM, + GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or @@ -248,7 +216,8 @@ void launchMLAFlashInfer( uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. // Used only for int8/fp8 KV cache. 
- uint32_t* semaphores, void* scratch, cudaStream_t stream); + uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, + uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream); #if STATIC_NB_K_HEADS constexpr uint32_t nbKHeads = NB_K_HEADS; diff --git a/csrc/xqa/mhaUtils.cuh b/csrc/xqa/mhaUtils.cuh index 5a4bf4f8f5..869862f204 100644 --- a/csrc/xqa/mhaUtils.cuh +++ b/csrc/xqa/mhaUtils.cuh @@ -22,16 +22,21 @@ struct IndexedHeadPtrImpl { uint32_t const* indices; // values are in range [0, beamWidth) Head* pool; Vec const* pageIndices; - uint32_t nbKHeads; - uint32_t offset; // applied onto pool + pointers + uint32_t tokenOffset; // token offset within the first page + uint32_t headIdx; // head index + uint32_t stride_page; // stride for each page (in units of Head) + uint32_t stride_token; // stride for each token (in units of Head) + uint32_t stride_head; // stride for each head (in units of Head) __device__ inline Head& operator[](uint32_t i) const { return *(*this + i); } __device__ inline Head* operator+(uint32_t i) const { - assert(indices[i] < beamWidth); - assert(nbPages == 1 || offset % tokensPerPage == 0); - auto const pageIdx = pageIndices[indices[i]][nbPages == 1 ? 0U : i / tokensPerPage]; - return pool + (tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage); + uint32_t const beamIdx = indices[i]; + assert(beamIdx < beamWidth); + uint32_t const absoluteTokenIdx = tokenOffset + i; + auto const pageIdx = pageIndices[beamIdx][nbPages == 1 ? 0U : absoluteTokenIdx / tokensPerPage]; + return pool + pageIdx * stride_page + (absoluteTokenIdx % tokensPerPage) * stride_token + + headIdx * stride_head; } }; @@ -59,24 +64,21 @@ struct HeadPtr { static_assert(tokensPerPage != 0 && nbPages != 0); Head* pool; Vec pageIndices; - uint32_t nbKHeads; - uint32_t offset; // offset inside the first page. + uint32_t tokenOffset; // token offset within the first page + uint32_t headIdx; // head index + uint32_t stride_page; // stride for each page (in units of Head) + uint32_t stride_token; // stride for each token (in units of Head) + uint32_t stride_head; // stride for each head (in units of Head) __device__ inline Head& operator[](uint32_t i) const { return *(*this + i); } __device__ inline Head* operator+(uint32_t i) const { -#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE - auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage]; - return (pageIdx & (1U << 31)) ? nullptr - : pool + (tokensPerPage * nbKHeads * pageIdx + offset + - (i % tokensPerPage) * nbKHeads); -#else - assert(nbPages == 1 || offset % tokensPerPage == 0); - auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage]; + uint32_t const absoluteTokenIdx = tokenOffset + i; + auto const pageIdx = pageIndices[nbPages == 1 ? 0U : absoluteTokenIdx / tokensPerPage]; return (pageIdx & (1U << 31)) ? nullptr - : pool + (tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage); -#endif + : pool + pageIdx * stride_page + (absoluteTokenIdx % tokensPerPage) * stride_token + + headIdx * stride_head; } }; @@ -226,12 +228,8 @@ struct KVCacheList; template <> struct KVCacheList { -#if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM; GMemCacheHead* vCacheVLLM; -#else - GMemKVCacheHead* pool; -#endif KVCachePageIndex const* kvCachePageList; // shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]. 
SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility) @@ -279,16 +277,8 @@ __device__ inline Vec getPage(KVCacheList #pragma unroll for (uint32_t i = 0; i < nbLoadedPages; i++) { uint32_t const idxPage = idxPageBeg + i; -#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[maxNbPagesPerSeq * idxReq + idxPage] : kBAD_PAGE_INDEX); -#else - ret[i] = - (idxPage < nbPages ? cacheList.kvCachePageList[beamWidth * 2 * maxNbPagesPerSeq * idxReq + - 2 * maxNbPagesPerSeq * idxBeam + - maxNbPagesPerSeq * (isK ? 0U : 1U) + idxPage] - : kBAD_PAGE_INDEX); -#endif } return ret; } diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu index 0ba6dad585..495f4d2d46 100644 --- a/csrc/xqa/mha_sm90.cu +++ b/csrc/xqa/mha_sm90.cu @@ -201,11 +201,9 @@ struct alignas(128) SharedMem { ShmQWiseVec gemm1AccColMax; ShmQWiseVec gemm1AccColSum; -#if USE_PAGED_KV_CACHE static constexpr uint32_t nbPagesPerTile = gemm0CtaTileNbTokens >= tokensPerPage ? exactDiv(gemm0CtaTileNbTokens, tokensPerPage) : 1; Vec pages[2]; // one for K and one for V -#endif // mem barriers @@ -271,11 +269,9 @@ struct KVTilePartLoader { static constexpr uint32_t nbParts = cacheHeadNbParts; static constexpr uint32_t partElems = exactDiv(headElems, nbParts); -#if USE_PAGED_KV_CACHE static_assert(gemm0CtaTileNbTokens % tokensPerPage == 0 || tokensPerPage % gemm0CtaTileNbTokens == 0); static constexpr uint32_t nbPagesPerTile = SharedMem::nbPagesPerTile; -#endif uint32_t const nbKHeads; KVCacheList const& cacheList; @@ -283,21 +279,15 @@ struct KVTilePartLoader { uint32_t const idxHeadGrp; CUtensorMap const& tensorMap; -#if USE_PAGED_KV_CACHE uint32_t const nbPages; // for bound check Vec& pages; uint32_t idxTileRef; // idxTile used to load the pages -#endif uint32_t const baseOffset; __device__ KVTilePartLoader(bool isK, uint32_t nbKHeads, KVCacheList const& cacheList, uint32_t idxReq, - uint32_t idxHeadGrp, CUtensorMap const& tensorMap -#if USE_PAGED_KV_CACHE - , - uint32_t nbPages, Vec& pageBuf -#endif - ); + uint32_t idxHeadGrp, CUtensorMap const& tensorMap, uint32_t nbPages, + Vec& pageBuf); // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache template __device__ void loadData( @@ -638,12 +628,8 @@ __launch_bounds__(128 * 3) uint32_t const batchSize, float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and // V cache. Used only for int8/fp8 KV cache. 
-#if PAGED_KV_CACHE_LAYOUT == 1 __grid_constant__ CUtensorMap const tensorMapVLLMK, __grid_constant__ CUtensorMap const tensorMapVLLMV, -#else - __grid_constant__ CUtensorMap const tensorMap, -#endif #if SPEC_DEC SpecDecParams const specDecParams, #endif @@ -733,16 +719,10 @@ __launch_bounds__(128 * 3) uint32_t const ctaInputTokBeg = reqInputTokBeg + ctaTokOffset; auto const warpIdx = getWarpIdx(uint3{128, 1, 3}); auto const wid = warpIdx.z * 4 + warpIdx.x; -#if PAGED_KV_CACHE_LAYOUT == 1 if (wid == 0 && warpElectSync()) { tma::prefetchTensorMap(tensorMapVLLMK); tma::prefetchTensorMap(tensorMapVLLMV); } -#else - if (wid == 0 && warpElectSync()) { - tma::prefetchTensorMap(tensorMap); - } -#endif extern __shared__ char smemByteBuf[]; assert(dynamicSmemSize() >= sizeof(SharedMem)); SharedMem& smem = *reinterpret_cast(&smemByteBuf[0]); @@ -768,9 +748,7 @@ __launch_bounds__(128 * 3) } __syncthreads(); -#if USE_PAGED_KV_CACHE uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage); -#endif constexpr bool isKVCacheQuantized = (cacheElemSize < 2); assert(idxKTileInit < nbTiles); @@ -1316,18 +1294,8 @@ __launch_bounds__(128 * 3) asm volatile("fence.proxy.async.shared::cta;\n"); unused(smem.qBar.produced.arrive()); } else if (warpIdx.x == nbQLdWarps) { // load k - KVTilePartLoader kTilePartLoader{true, nbKHeads, cacheList, idxReq, idxHeadGrp, -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 - tensorMapVLLMK, -#else - tensorMap, -#endif - nbPages, smem.pages[0] -#else - tensorMap -#endif - }; + KVTilePartLoader kTilePartLoader{true, nbKHeads, cacheList, idxReq, + idxHeadGrp, tensorMapVLLMK, nbPages, smem.pages[0]}; for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; kTilePartLoader.loadPages(idxKTile); @@ -1385,18 +1353,8 @@ __launch_bounds__(128 * 3) } } } else if (warpIdx.x == nbQLdWarps + 1) { // load v - KVTilePartLoader vTileLoader{false, nbKHeads, cacheList, idxReq, idxHeadGrp, -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 - tensorMapVLLMV, -#else - tensorMap, -#endif - nbPages, smem.pages[1] -#else - tensorMap -#endif - }; + KVTilePartLoader vTileLoader{false, nbKHeads, cacheList, idxReq, + idxHeadGrp, tensorMapVLLMV, nbPages, smem.pages[1]}; for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq; vTileLoader.loadPages(idxVTile); @@ -1730,35 +1688,16 @@ __device__ inline void F16QToF8Converter::store( __device__ inline KVTilePartLoader::KVTilePartLoader(bool isK, uint32_t nbKHeads, KVCacheList const& cacheList, uint32_t idxReq, uint32_t idxHeadGrp, - CUtensorMap const& tensorMap -#if USE_PAGED_KV_CACHE - , - uint32_t nbPages, - Vec& pageBuf -#endif - ) + CUtensorMap const& tensorMap, uint32_t nbPages, + Vec& pageBuf) : nbKHeads{nbKHeads}, cacheList{cacheList}, idxReq{idxReq}, idxHeadGrp{idxHeadGrp}, - tensorMap{tensorMap} -#if USE_PAGED_KV_CACHE - , + tensorMap{tensorMap}, nbPages{nbPages}, - pages{pageBuf} -#if PAGED_KV_CACHE_LAYOUT == 1 - , - baseOffset{idxReq * cacheList.maxNbPagesPerSeq} -#else - , - baseOffset{((idxReq * beamWidth) * 2 + (isK ? 0 : 1)) * cacheList.maxNbPagesPerSeq} -#endif -#else - , - baseOffset{(idxReq * beamWidth) * 2 + (isK ? 
0 : 1)} -#endif -{ -} + pages{pageBuf}, + baseOffset{idxReq * cacheList.maxNbPagesPerSeq} {} // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache template @@ -1766,38 +1705,22 @@ __device__ inline void KVTilePartLoader::loadData( Array2D& dst, uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar) { static_assert(nbTokens == gemm0CtaTileNbTokens); -#if USE_PAGED_KV_CACHE assert(idxTile == idxTileRef); if constexpr (nbTokens < tokensPerPage) { assert(nbPagesPerTile == 1); uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); -#if PAGED_KV_CACHE_LAYOUT == 1 tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t)pages[0]}, bar); -#else - tma::loadAsync(&dst, tensorMap, - DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t)pages[0]}, bar); -#endif } else { #pragma unroll for (uint32_t i = 0; i < nbPagesPerTile; i++) { -#if PAGED_KV_CACHE_LAYOUT == 1 tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, DimsLE<4>{partElems * idxPart, idxHeadGrp, 0, (uint32_t)pages[i]}, bar); -#else - tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, - DimsLE<4>{partElems * idxPart, 0, idxHeadGrp, (uint32_t)pages[i]}, bar); -#endif } } -#else - tma::loadAsync(&dst, tensorMap, - DimsLE<4>{partElems * idxPart, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); -#endif } __device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile) { -#if USE_PAGED_KV_CACHE uint32_t const idxPageBeg = gemm0CtaTileNbTokens >= tokensPerPage ? nbPagesPerTile * idxTile : idxTile / exactDiv(tokensPerPage, gemm0CtaTileNbTokens); @@ -1812,28 +1735,13 @@ __device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile) { } idxTileRef = idxTile; __syncwarp(); -#endif } __device__ inline GMemKVCacheHead& KVTilePartLoader::getHead(uint32_t pos) { constexpr uint32_t nbTokens = gemm0CtaTileNbTokens; -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 // Raise a runtime error indicating not implemented - assert(false && "KVTilePartLoader::getHead is not implemented for PAGED_KV_CACHE_LAYOUT == 1"); + assert(false && "KVTilePartLoader::getHead is not implemented"); __trap(); -#else - uint32_t const idxTile = pos / nbTokens; - assert(idxTile == idxTileRef); - uint32_t const offset = pos % tokensPerPage; - return cacheList - .pool[tokensPerPage * (nbKHeads * pages[pos % nbTokens / tokensPerPage] + idxHeadGrp) + - offset]; -#endif -#else - // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity] - return cacheList.data[cacheList.capacity * (baseOffset * nbKHeads + idxHeadGrp) + pos]; -#endif } #if SWAP_AB @@ -1966,9 +1874,12 @@ __device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gme for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) { static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); - ret[i] = reinterpret_cast, - exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( - gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; + uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound); + uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols; +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + ret[i][j] = gmemVec[baseOffset + j]; + } } return ret; } @@ -3011,18 +2922,10 @@ void launchHopperF8MHA( InputHead const* q, #endif float const* attentionSinks, // [headGrpSize] -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, 
GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, // global pool of pages -#endif KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] -#else - GMemKVCacheHead* kvCacheData, -#endif uint32_t maxSeqLen, uint32_t const* seqLen, #if USE_BEAM_SEARCH BeamSearchParams const& beamSearchParams, @@ -3033,7 +2936,8 @@ void launchHopperF8MHA( #if SPEC_DEC SpecDecParams const& specDecParams, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream) { + uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, + uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { if (beamWidth != 1) { throw std::runtime_error("not implemented"); } @@ -3070,8 +2974,7 @@ void launchHopperF8MHA( // nbInputSeqSplit dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; - auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); -#if USE_PAGED_KV_CACHE + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); auto const dtype = [] { if (std::is_same_v) { @@ -3084,62 +2987,18 @@ void launchHopperF8MHA( throw std::runtime_error("unsupported cache element type"); }(); -#if PAGED_KV_CACHE_LAYOUT == 1 KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; - auto const tensorMapVLLMK = - makeTensorMapForPagedKVCache(kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, - cacheHeadPartElems, gemm0CtaTileNbTokens); - auto const tensorMapVLLMV = - makeTensorMapForPagedKVCache(vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, - cacheHeadPartElems, gemm0CtaTileNbTokens); -#else - KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; - auto const tensorMap = - makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, - cacheHeadPartElems, gemm0CtaTileNbTokens); -#endif + auto const tensorMapVLLMK = makeTensorMapForPagedKVCache( + kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, + gemm0CtaTileNbTokens, kv_stride_page, kv_stride_token, kv_stride_head); + auto const tensorMapVLLMV = makeTensorMapForPagedKVCache( + vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, + gemm0CtaTileNbTokens, kv_stride_page, kv_stride_token, kv_stride_head); - cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, -#if SLIDING_WINDOW - slidingWinSize, -#endif - qScale, output, -#if LOW_PREC_OUTPUT - rcpOutScale, -#endif -#if USE_INPUT_KV - qkv, -#if ROPE_STYLE != 0 - ropeCosSin, -#endif -#else - q, -#endif - attentionSinks, cacheList, -#if USE_BEAM_SEARCH - beamSearchParams, -#endif - batchSize, kvCacheScale, -#if PAGED_KV_CACHE_LAYOUT == 1 - tensorMapVLLMK, tensorMapVLLMV, -#else - tensorMap, -#endif -#if SPEC_DEC - specDecParams, -#endif - semaphores, scratch); -#else - KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; - static_assert(!usePagedKVCache); - assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); - auto const tensorMap = makeTensorMapForContiguousKVCache( - kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, - batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); cudaError_t const err = - cudaLaunchKernelEx(&launchCfg, kernel_mha, 
nbKHeads, + cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, #if SLIDING_WINDOW slidingWinSize, #endif @@ -3159,8 +3018,11 @@ void launchHopperF8MHA( #if USE_BEAM_SEARCH beamSearchParams, #endif - batchSize, kvCacheScale, tensorMap, semaphores, scratch); + batchSize, kvCacheScale, tensorMapVLLMK, tensorMapVLLMV, +#if SPEC_DEC + specDecParams, #endif + semaphores, scratch); checkCuda(err); } #endif @@ -3180,18 +3042,16 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads float const* rcpOutScale, #endif InputHead const* q, float const* attentionSinks, -#if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, -#else - GMemCacheHead* pool, -#endif KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream) { + uint32_t* semaphores, void* scratch, bool enable_pdl, + uint64_t kv_stride_page, uint64_t kv_stride_token, + uint64_t kv_stride_head, cudaStream_t stream) { uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { float const factor = 0.25f; return mha::min( @@ -3207,8 +3067,7 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads #endif dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; - auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); -#if USE_PAGED_KV_CACHE + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); auto const dtype = [] { if (std::is_same_v) { @@ -3221,22 +3080,15 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads throw std::runtime_error("unsupported cache element type"); }(); -#if PAGED_KV_CACHE_LAYOUT == 1 KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; - auto const tensorMapVLLMK = - makeTensorMapForPagedKVCache(kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, - cacheHeadPartElems, gemm0CtaTileNbTokens); - auto const tensorMapVLLMV = - makeTensorMapForPagedKVCache(vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, - cacheHeadPartElems, gemm0CtaTileNbTokens); -#else - KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; - auto const tensorMap = - makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, - cacheHeadPartElems, gemm0CtaTileNbTokens); -#endif + auto const tensorMapVLLMK = makeTensorMapForPagedKVCache( + kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, + gemm0CtaTileNbTokens, kv_stride_page, kv_stride_token, kv_stride_head); + auto const tensorMapVLLMV = makeTensorMapForPagedKVCache( + vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, + gemm0CtaTileNbTokens, kv_stride_page, kv_stride_token, kv_stride_head); cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, #if SLIDING_WINDOW @@ -3247,33 +3099,11 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads rcpOutScale, #endif q, attentionSinks, cacheList, batchSize, kvCacheScale, -#if PAGED_KV_CACHE_LAYOUT == 1 tensorMapVLLMK, tensorMapVLLMV, -#else - tensorMap, -#endif #if 
SPEC_DEC specDecParams, #endif semaphores, scratch); -#else - KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; - static_assert(!usePagedKVCache); - assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); - auto const tensorMap = makeTensorMapForContiguousKVCache( - kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, - batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); - cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, -#if SLIDING_WINDOW - slidingWinSize, -#endif - qScale, output, -#if LOW_PREC_OUTPUT - rcpOutScale, -#endif - q, attentionSinks, cacheList, batchSize, kvCacheScale, - tensorMap, semaphores, scratch); -#endif checkCuda(err); } #endif diff --git a/csrc/xqa/mla_sm120.cu b/csrc/xqa/mla_sm120.cu index 088f601015..2396fb8c5b 100644 --- a/csrc/xqa/mla_sm120.cu +++ b/csrc/xqa/mla_sm120.cu @@ -64,11 +64,9 @@ inline constexpr uint32_t nbRegsForMathWarps = 232; inline constexpr bool computeRowSumFromF8 = true; struct KVTilePartLoader { -#if USE_PAGED_KV_CACHE static_assert(tokensPerPage % tokensPerTile == 0 || tokensPerTile % tokensPerPage == 0); static inline constexpr uint32_t nbPagesPerTile = tokensPerTile >= tokensPerPage ? exactDiv(tokensPerTile, tokensPerPage) : 1; -#endif static inline constexpr uint32_t const nbKHeads = 1; KVCacheList const& cacheList; @@ -78,20 +76,13 @@ struct KVTilePartLoader { CUtensorMap const& tensorMap; // if greater than 1, then we need unrolling for the loading loop. Seems 1 is fine for latency. static inline constexpr uint32_t nbPageBuffers = 1; -#if USE_PAGED_KV_CACHE uint32_t const nbPages; // for bound check Vec pageBuffers[nbPageBuffers]; uint32_t idxTileRef = ~0U; // idxTile used to load the pages -#endif uint32_t const baseOffset; __device__ KVTilePartLoader(KVCacheList const& cacheList, uint32_t idxReq, - CUtensorMap const& tensorMap -#if USE_PAGED_KV_CACHE - , - uint32_t nbPages -#endif - ); + CUtensorMap const& tensorMap, uint32_t nbPages); // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache template __device__ void loadData(Array2D& dst, @@ -102,30 +93,13 @@ struct KVTilePartLoader { }; __device__ inline KVTilePartLoader::KVTilePartLoader(KVCacheList const& cacheList, - uint32_t idxReq, CUtensorMap const& tensorMap -#if USE_PAGED_KV_CACHE - , - uint32_t nbPages -#endif - ) + uint32_t idxReq, CUtensorMap const& tensorMap, + uint32_t nbPages) : cacheList{cacheList}, idxReq{idxReq}, - tensorMap{tensorMap} -#if USE_PAGED_KV_CACHE - , - nbPages{nbPages} -#if PAGED_KV_CACHE_LAYOUT == 1 - , - baseOffset{idxReq * cacheList.maxNbPagesPerSeq} -#else - , - baseOffset{((idxReq * beamWidth) * 2) * cacheList.maxNbPagesPerSeq} -#endif -#else - , - baseOffset{(idxReq * beamWidth) * 2} -#endif -{ + tensorMap{tensorMap}, + nbPages{nbPages}, + baseOffset{idxReq * cacheList.maxNbPagesPerSeq} { #pragma unroll for (auto& pageBuffer : pageBuffers) { pageBuffer.fill(kBAD_PAGE_INDEX); @@ -138,45 +112,27 @@ __device__ inline void KVTilePartLoader::loadData( Array2D& dst, uint32_t idxTile, uint32_t idxElemBeg, CtaBarrier& bar, uint32_t idxPageBuf) { static_assert(nbTokens == tokensPerTile); -#if USE_PAGED_KV_CACHE assert(idxTile == idxTileRef); auto const& pages = pageBuffers[idxPageBuf]; if constexpr (nbTokens < tokensPerPage) { assert(nbPagesPerTile == 1); uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); if (warpElectSync()) { -#if PAGED_KV_CACHE_LAYOUT == 1 tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, 
idxHeadGrp, offset, (uint32_t)pages[0]}, bar); -#else - tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, offset, idxHeadGrp, (uint32_t)pages[0]}, - bar); -#endif } } else { #pragma unroll for (uint32_t i = 0; i < nbPagesPerTile; i++) { if (warpElectSync()) { -#if PAGED_KV_CACHE_LAYOUT == 1 tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, DimsLE<4>{idxElemBeg, idxHeadGrp, 0, (uint32_t)pages[i]}, bar); -#else - tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, - DimsLE<4>{idxElemBeg, 0, idxHeadGrp, (uint32_t)pages[i]}, bar); -#endif } } } -#else - if (warpElectSync()) { - tma::loadAsync(&dst, tensorMap, - DimsLE<4>{idxElemBeg, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); - } -#endif } __device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile, uint32_t idxPageBuf) { -#if USE_PAGED_KV_CACHE uint32_t const idxPageBeg = tokensPerTile >= tokensPerPage ? nbPagesPerTile * idxTile : idxTile / exactDiv(tokensPerPage, tokensPerTile); @@ -188,7 +144,6 @@ __device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile, uint32_t id idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; } idxTileRef = idxTile; -#endif } using Mat16x32 = Vec; @@ -860,12 +815,7 @@ struct Producer { }; __device__ inline void Producer::loadK() { - KVTilePartLoader loader{args.cacheList, idxReq, args.tensorMapK -#if USE_PAGED_KV_CACHE - , - divUp(seqLen, tokensPerPage) -#endif - }; + KVTilePartLoader loader{args.cacheList, idxReq, args.tensorMapK, divUp(seqLen, tokensPerPage)}; #pragma unroll 1 for (uint32_t iter = 0; true; iter++) { @@ -1340,12 +1290,7 @@ __device__ inline void Consumer::loadX() { } __device__ inline void Consumer::loadV() { - KVTilePartLoader loader(args.cacheList, idxReq, args.tensorMapV -#if USE_PAGED_KV_CACHE - , - divUp(seqLen, tokensPerPage) -#endif - ); + KVTilePartLoader loader(args.cacheList, idxReq, args.tensorMapV, divUp(seqLen, tokensPerPage)); for (uint32_t iter = 0; true; iter++) { uint32_t const idxTile = iterToTile(iter); if (idxTile >= nbTiles()) { @@ -1707,24 +1652,15 @@ void launchMLA( cudaDeviceProp const& prop, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed float qScale, OutputHead* output, InputHead const* q, -#if USE_PAGED_KV_CACHE -#if PAGED_KV_CACHE_LAYOUT == 1 - GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout - GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout -#else - GMemCacheHead* pool, // global pool of pages -#endif - KVCachePageIndex const* - kvCachePageList, // device pointer. shape: - // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or - // [batchSize][maxNbPagesPerSeq] (Layout 1) -#else - GMemKVCacheHead* kvCacheData, -#endif + GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout + GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout + KVCachePageIndex const* kvCachePageList, // device pointer. shape: + // [batchSize][maxNbPagesPerSeq] (Layout 1) uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. // Used only for int8/fp8 KV cache. 
- uint32_t* semaphores, void* scratch, cudaStream_t stream) { + uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, + uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { #if IS_MLA static_assert( SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0, @@ -1762,15 +1698,10 @@ void launchMLA( // nbInputSeqSplit dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize}; dim3 const dimCta{warp_size * 4 * 3, 1, 1}; - auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); -#if USE_PAGED_KV_CACHE + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); -#if PAGED_KV_CACHE_LAYOUT == 1 KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; -#else - KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; -#endif auto const dtype = [] { if (std::is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; @@ -1784,17 +1715,12 @@ void launchMLA( auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead, headGrpSize * inputSeqLen * batchSize, partElemsK); -#if PAGED_KV_CACHE_LAYOUT == 1 auto const tensorMapK = makeTensorMapForPagedKVCache( - kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile); + kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile, + kv_stride_page, kv_stride_token, kv_stride_head); auto const tensorMapV = makeTensorMapForPagedKVCache( - vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile); -#else - auto const tensorMapK = makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, - tokensPerPage, partElemsK, tokensPerTile); - auto const tensorMapV = makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, - tokensPerPage, partElemsV, tokensPerTile); -#endif + vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile, + kv_stride_page, kv_stride_token, kv_stride_head); uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z; auto const cgaXBuf = static_cast*>(scratch); @@ -1848,20 +1774,15 @@ void launchMLAFlashInfer( uint32_t multiProcessorCount, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed float qScale, OutputHead* output, InputHead const* q, -#if PAGED_KV_CACHE_LAYOUT == 1 - GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout - GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout -#else - GMemCacheHead* pool, // global pool of pages -#endif - KVCachePageIndex const* - kvCachePageList, // device pointer. shape: - // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or - // [batchSize][maxNbPagesPerSeq] (Layout 1) + GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout + GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout + KVCachePageIndex const* kvCachePageList, // device pointer. shape: + // [batchSize][maxNbPagesPerSeq] (Layout 1) uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. // Used only for int8/fp8 KV cache. 
- uint32_t* semaphores, void* scratch, cudaStream_t stream) { + uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, + uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { #if IS_MLA static_assert( SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0, @@ -1885,15 +1806,10 @@ void launchMLAFlashInfer( // nbInputSeqSplit dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize}; dim3 const dimCta{warp_size * 4 * 3, 1, 1}; - auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); -#if USE_PAGED_KV_CACHE + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); -#if PAGED_KV_CACHE_LAYOUT == 1 KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; -#else - KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; -#endif auto const dtype = [] { if (std::is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; @@ -1907,17 +1823,12 @@ void launchMLAFlashInfer( auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead, headGrpSize * inputSeqLen * batchSize, partElemsK); -#if PAGED_KV_CACHE_LAYOUT == 1 auto const tensorMapK = makeTensorMapForPagedKVCache( - kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile); + kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile, + kv_stride_page, kv_stride_token, kv_stride_head); auto const tensorMapV = makeTensorMapForPagedKVCache( - vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile); -#else - auto const tensorMapK = makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, - tokensPerPage, partElemsK, tokensPerTile); - auto const tensorMapV = makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, - tokensPerPage, partElemsV, tokensPerTile); -#endif + vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile, + kv_stride_page, kv_stride_token, kv_stride_head); uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z; auto const cgaXBuf = static_cast*>(scratch); @@ -1925,36 +1836,6 @@ void launchMLAFlashInfer( cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, tensorMapQ, tensorMapK, tensorMapV, qScale, output, cacheList, batchSize, kvCacheScale, cgaXBuf, semaphores, partialResults); -#else - KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; - static_assert(!usePagedKVCache); - assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); - auto const tensorMap = makeTensorMapForContiguousKVCache( - kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, - batchSize, gemm0CtaTileNbTokens); - cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, -#if SLIDING_WINDOW - slidingWinSize, -#endif - qScale, output, -#if LOW_PREC_OUTPUT - rcpOutScale, -#endif -#if USE_INPUT_KV - qkv, -#if ROPE_STYLE != 0 - ropeCosSin, -#endif -#else - q, -#endif - cacheList, -#if USE_BEAM_SEARCH - beamSearchParams, -#endif - batchSize, kvCacheScale, tensorMap, semaphores, scratch); -#endif checkCuda(err); #endif } -#endif diff --git a/csrc/xqa/tensorMap.cpp b/csrc/xqa/tensorMap.cpp index e79272b018..3e76635308 100644 --- a/csrc/xqa/tensorMap.cpp +++ b/csrc/xqa/tensorMap.cpp @@ -75,28 +75,19 @@ CUtensorMap 
makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataT CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t nbKHeads, uint32_t tokensPerPage, uint32_t partElems, - uint32_t nbTokensPerTile) { + uint32_t nbTokensPerTile, uint64_t stride_page, + uint64_t stride_token, uint64_t stride_head) { CUtensorMap tensorMap{}; uint32_t elemBytes = getElemBytes(dataType); -// VLLM Layout -#if PAGED_KV_CACHE_LAYOUT == 1 + // VLLM Layout uint64_t const globalDims[] = {headElems, nbKHeads, tokensPerPage, 1U << 31}; uint32_t const headBytes = elemBytes * headElems; - uint64_t const globalStrides[] = {headBytes, headBytes * nbKHeads, - headBytes * nbKHeads * tokensPerPage}; + // Use provided strides (in elements) and convert to bytes + uint64_t const globalStrides[] = {stride_head * elemBytes, stride_token * elemBytes, + stride_page * elemBytes}; uint32_t const partBytes = partElems * elemBytes; uint32_t const boxDims[] = {partElems, 1, mha::min(tokensPerPage, nbTokensPerTile), 1}; uint32_t const elemStrides[] = {1, 1, 1, 1}; - // XQA Original Layout -#else - uint64_t const globalDims[] = {headElems, tokensPerPage, nbKHeads, 1U << 31}; - uint32_t const headBytes = elemBytes * headElems; - uint64_t const globalStrides[] = {headBytes, headBytes * tokensPerPage, - headBytes * tokensPerPage * nbKHeads}; - uint32_t const partBytes = partElems * elemBytes; - uint32_t const boxDims[] = {partElems, mha::min(tokensPerPage, nbTokensPerTile), 1, 1}; - uint32_t const elemStrides[] = {1, 1, 1, 1}; -#endif auto const swizzle = [&] { switch (partBytes) { diff --git a/csrc/xqa/tensorMap.h b/csrc/xqa/tensorMap.h index d0b2c76b96..aae90c5466 100644 --- a/csrc/xqa/tensorMap.h +++ b/csrc/xqa/tensorMap.h @@ -13,4 +13,5 @@ CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataT CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t nbKHeads, uint32_t tokensPerPage, uint32_t partElems, - uint32_t nbTokensPerTile); + uint32_t nbTokensPerTile, uint64_t stride_page, + uint64_t stride_token, uint64_t stride_head); diff --git a/csrc/xqa/utils.cuh b/csrc/xqa/utils.cuh index ff65762caf..f96d83f5f5 100644 --- a/csrc/xqa/utils.cuh +++ b/csrc/xqa/utils.cuh @@ -42,7 +42,7 @@ inline constexpr int32_t kBAD_PAGE_INDEX = -1; __constant__ constexpr float kE4M3_MAX = 448.F; #ifdef __CUDA_ARCH__ -#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 +#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210 constexpr uint32_t kMAX_SMEM_SIZE = (99u << 10); #elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 constexpr uint32_t kMAX_SMEM_SIZE = (163u << 10); diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 089a118541..bbe314b7e3 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -21,30 +21,28 @@ using tvm::ffi::Optional; #if MLA_WRAPPER void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q, -#if PAGED_KV_CACHE_LAYOUT == 1 - TensorView kCacheVLLM, TensorView vCacheVLLM, -#else - TensorView pool, -#endif - TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, - int64_t batchSize, TensorView kvCacheScale, TensorView semaphores, - TensorView scratch) { + TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, + int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, + TensorView kvCacheScale, TensorView 
semaphores, TensorView scratch, + bool enable_pdl) { auto stream = get_stream(output.device()); + // Extract strides from TensorView (in elements, not bytes) + uint64_t kv_stride_page = kCacheVLLM.stride(0); + uint64_t kv_stride_token = kCacheVLLM.stride(-2); + uint64_t kv_stride_head = kCacheVLLM.stride(-3); + launchMLAFlashInfer(multiProcessorCount, 1, qScale, reinterpret_cast(output.data_ptr()), reinterpret_cast(q.data_ptr()), -#if PAGED_KV_CACHE_LAYOUT == 1 reinterpret_cast(kCacheVLLM.data_ptr()), reinterpret_cast(vCacheVLLM.data_ptr()), -#else - reinterpret_cast(pool.data_ptr()), -#endif reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, reinterpret_cast(kvCacheScale.data_ptr()), reinterpret_cast(semaphores.data_ptr()), - reinterpret_cast(scratch.data_ptr()), stream); + reinterpret_cast(scratch.data_ptr()), enable_pdl, kv_stride_page, + kv_stride_token, kv_stride_head, stream); } #else @@ -53,36 +51,32 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK #if LOW_PREC_OUTPUT TensorView rcpOutScale, #endif - TensorView q, Optional attentionSinks, -#if PAGED_KV_CACHE_LAYOUT == 1 - TensorView kCacheVLLM, TensorView vCacheVLLM, -#else - TensorView pool, -#endif - TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, - int64_t batchSize, TensorView kvCacheScale, + TensorView q, Optional attentionSinks, TensorView kCacheVLLM, + TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, + TensorView seqLen, int64_t batchSize, TensorView kvCacheScale, #if SPEC_DEC int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif - TensorView semaphores, TensorView scratch) { + TensorView semaphores, TensorView scratch, bool enable_pdl) { auto stream = get_stream(output.device()); float const* attentionSinksPtr = attentionSinks.has_value() ? reinterpret_cast(attentionSinks.value().data_ptr()) : nullptr; auto const mha_func = run_sm90_fp8_mha ? 
&launchHopperF8MHAFlashInfer : &launchMHAFlashInfer; + // Extract strides from TensorView (in elements, not bytes) + uint64_t kv_stride_page = kCacheVLLM.stride(0); + uint64_t kv_stride_token = kCacheVLLM.stride(-3); + uint64_t kv_stride_head = kCacheVLLM.stride(-2); + mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT reinterpret_cast(rcpOutScale.data_ptr()), #endif reinterpret_cast(q.data_ptr()), attentionSinksPtr, -#if PAGED_KV_CACHE_LAYOUT == 1 reinterpret_cast(kCacheVLLM.data_ptr()), reinterpret_cast(vCacheVLLM.data_ptr()), -#else - reinterpret_cast(pool.data_ptr()), -#endif reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, reinterpret_cast(kvCacheScale.data_ptr()), @@ -91,6 +85,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK reinterpret_cast(mask.data_ptr()), #endif reinterpret_cast(semaphores.data_ptr()), - reinterpret_cast(scratch.data_ptr()), stream); + reinterpret_cast(scratch.data_ptr()), enable_pdl, kv_stride_page, kv_stride_token, + kv_stride_head, stream); } #endif diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 45bc2c58ad..d418e5cd90 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -21,6 +21,7 @@ import torch +from .xqa import xqa from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .jit import ( gen_batch_decode_mla_module, @@ -983,7 +984,6 @@ def plan( else: kv_lens_arr_host = seq_lens.cpu() if self._backend == "trtllm-gen": - assert self._kv_layout == "HND" assert logits_soft_cap == 0.0 self._max_kv_len = max(kv_lens_arr_host).item() self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( @@ -1226,6 +1226,7 @@ def run( if enable_pdl is None: enable_pdl = device_support_pdl(q.device) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) + if self._kv_layout == "NHD": page_size = k_cache.shape[1] else: @@ -1234,6 +1235,12 @@ def run( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) + # Convert NHD layout to HND for trtllm-gen backend + if self._backend == "trtllm-gen" and self._kv_layout == "NHD": + # For NHD: [..., N, H, D] -> HND: [..., H, N, D] + k_cache = k_cache.transpose(-3, -2) + v_cache = v_cache.transpose(-3, -2) + pos_encoding_mode = self._pos_encoding_mode window_left = self._window_left if window_left is None else window_left if self._backend != "trtllm-gen": @@ -2066,7 +2073,9 @@ def trtllm_batch_decode_with_kv_cache( o_sf_scale: Optional[float] = None, o_sf_vec_size: Optional[int] = None, sinks: Optional[List[torch.Tensor]] = None, - enable_pdl: bool = None, + kv_layout: str = "HND", + enable_pdl: Optional[bool] = None, + backend: str = "auto", q_len_per_req: Optional[int] = 1, ) -> Union[torch.Tensor, FP4Tensor]: """ @@ -2076,8 +2085,11 @@ def trtllm_batch_decode_with_kv_cache( query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] - If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim] + If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``, + or [num_pages, 1 or 2, page_size, num_kv_heads, 
head_dim] if :attr:`kv_layout` is ``NHD``. + If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``, + or [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``. + The first tensor is the key cache, and the second tensor is the value cache. workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. workspace @@ -2116,9 +2128,19 @@ def trtllm_batch_decode_with_kv_cache( sinks : Optional[List[torch.Tensor]] = None additional value per head in the denominator of the softmax. - enable_pdl : bool + kv_layout : str = "HND" + The layout of the input k/v tensors, can be either ``NHD`` or ``HND``. + Defaults to ``HND``. + + enable_pdl : Optional[bool] = None Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization - Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode. + When set to ``None``, PDL will be enabled automatically if the device supports it. + + backend : str = "auto" + The implementation backend, which can be ``auto``, ``xqa``, or ``trtllm-gen``. Defaults to ``auto``. + When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. + For sm_100 and sm_103 (Blackwell architecture), ``auto`` will choose the ``trtllm-gen`` backend. + For sm_90 (Hopper architecture) and sm_120 (Blackwell architecture), ``auto`` will choose the ``xqa`` backend. Returns ------- @@ -2140,80 +2162,255 @@ def trtllm_batch_decode_with_kv_cache( # it doesn't change underlying storage k_cache, v_cache = kv_cache.unbind(dim=1) - run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode - sm_count = get_device_sm_count(query.device) - - if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): - assert query.dtype == torch.float8_e4m3fn, ( - "query must be fp8 when out_dtype is nvfp4."
+ if backend == "auto": + backend = ( + "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" ) - assert o_sf_scale is not None - assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported" - o_sf_vec_size = o_sf_vec_size or 16 - fp4_out_shape = query.shape[:-1] + (ceil_div(query.shape[-1], 2),) + if backend == "xqa": + # xqa backend doesn't support nvfp4 output + if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): + raise ValueError("xqa backend does not support nvfp4 output") + if o_sf_scale is not None or o_sf_vec_size is not None: + raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size") - if isinstance(out, FP4Tensor): - fp4_out_scale_shape = ( - out.scale.shape[0], - round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), + # Handle out and out_dtype + if out_dtype is None: + out_dtype = out.dtype if out is not None else query.dtype + if out is None: + out = torch.empty_like(query, dtype=out_dtype) + + # Call xqa_batch_decode_with_kv_cache + return xqa_batch_decode_with_kv_cache( + query=query, + kv_cache=(k_cache, v_cache), + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + window_left=window_left, + out=out, + sinks=sinks, + kv_layout=kv_layout, + enable_pdl=enable_pdl, + q_len_per_req=q_len_per_req, + ) + elif backend == "trtllm-gen": + # Convert NHD layout to HND if necessary (transpose only changes stride, not data) + if kv_layout == "NHD": + # For NHD: [..., N, H, D] -> HND: [..., H, N, D] + k_cache = k_cache.transpose(-3, -2) + v_cache = v_cache.transpose(-3, -2) + + run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode + sm_count = get_device_sm_count(query.device) + + if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): + assert query.dtype == torch.float8_e4m3fn, ( + "query must be fp8 when out_dtype is nvfp4." ) - out_scale_factor = out.scale - o_sf_start_index = out.scale_start_index - out = out.data - # out_dtype may be None - out_dtype = out_dtype or "nvfp4" - elif out is None: - fp4_out_scale_shape = ( - round_up(query.shape[0], 128), - round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), + assert o_sf_scale is not None + assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported" + o_sf_vec_size = o_sf_vec_size or 16 + + fp4_out_shape = query.shape[:-1] + (ceil_div(query.shape[-1], 2),) + + if isinstance(out, FP4Tensor): + fp4_out_scale_shape = ( + out.scale.shape[0], + round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), + ) + out_scale_factor = out.scale + o_sf_start_index = out.scale_start_index + out = out.data + # out_dtype may be None + out_dtype = out_dtype or "nvfp4" + elif out is None: + fp4_out_scale_shape = ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), + ) + out_scale_factor = torch.empty( + fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device + ) + o_sf_start_index = 0 + out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device) + else: + raise ValueError(f"Invalid out: {out}") + + assert out_dtype == "nvfp4" + assert isinstance(out, torch.Tensor) + + # Use uint8 as the container dtype to compliant with next fp4 gemm. 
+ check_shape_dtype_device( + out, fp4_out_shape, torch.uint8, query.device, "out" ) - out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device + + check_shape_dtype_device( + out_scale_factor, + fp4_out_scale_shape, + torch.float8_e4m3fn, + query.device, + "out_scale_factor", ) + + # Check o_sf_start_index is valid + if ( + o_sf_start_index < 0 + or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0] + ): + raise ValueError( + f"o_sf_start_index is out of the valid range of out_scale_factor. " + f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, " + f"out_scale_factor.shape[0]={out_scale_factor.shape[0]}" + ) + + elif isinstance(out_dtype, torch.dtype) or out_dtype is None: + assert o_sf_scale is None + assert o_sf_vec_size is None + out_scale_factor = None o_sf_start_index = 0 - out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device) + if out_dtype is None: + out_dtype = out.dtype if out is not None else query.dtype + out = out if out is not None else torch.empty_like(query, dtype=out_dtype) + if out_dtype not in (query.dtype, torch.float16, torch.bfloat16): + raise ValueError(f"Unsupported out_dtype: {out_dtype}") + check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out") else: - raise ValueError(f"Invalid out: {out}") - - assert out_dtype == "nvfp4" - assert isinstance(out, torch.Tensor) + raise ValueError(f"Invalid out_dtype: {out_dtype}") - # Use uint8 as the container dtype to compliant with next fp4 gemm. - check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out") + bmm1_scale = ( + bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale + ) + bmm2_scale = ( + bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale + ) - check_shape_dtype_device( + run_func( + out, out_scale_factor, - fp4_out_scale_shape, - torch.float8_e4m3fn, - query.device, - "out_scale_factor", + query.view( + query.size(0) // q_len_per_req, + q_len_per_req, + query.size(1), + query.size(2), + ), + k_cache, + v_cache, + workspace_buffer, + block_tables, + seq_lens, + max_seq_len, + bmm1_scale, + bmm2_scale, + o_sf_scale or -1.0, + o_sf_vec_size or -1, + o_sf_start_index, + window_left, + sm_count, + enable_pdl, + workspace_buffer.numel() * workspace_buffer.element_size(), + sinks, ) - # Check o_sf_start_index is valid - if ( - o_sf_start_index < 0 - or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0] - ): - raise ValueError( - f"o_sf_start_index is out of the valid range of out_scale_factor. 
" - f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, " - f"out_scale_factor.shape[0]={out_scale_factor.shape[0]}" - ) + return ( + out + if out_dtype != "nvfp4" + else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) + ) + else: + raise KeyError(f"Backend {backend} not supported") - elif isinstance(out_dtype, torch.dtype) or out_dtype is None: - assert o_sf_scale is None - assert o_sf_vec_size is None - out_scale_factor = None - o_sf_start_index = 0 - if out_dtype is None: - out_dtype = out.dtype if out is not None else query.dtype - out = out if out is not None else torch.empty_like(query, dtype=out_dtype) - if out_dtype not in (query.dtype, torch.float16, torch.bfloat16): - raise ValueError(f"Unsupported out_dtype: {out_dtype}") - check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out") + +# xqa uses NHD layout +def xqa_batch_decode_with_kv_cache( + query: torch.Tensor, + kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + workspace_buffer: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + bmm1_scale: float, + bmm2_scale: float, + window_left: int = -1, + out: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + kv_layout: str = "NHD", + enable_pdl: bool = None, + q_len_per_req: Optional[int] = 1, +) -> torch.Tensor: + """ + Parameters + ---------- + query : torch.Tensor + query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request + + kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``, + or [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``. + If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``, + or [num_pages, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``. + + workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. + workspace + + block_tables : torch.Tensor + page_table of kv cache, [batch_size, num_pages] + + seq_lens : torch.Tensor + A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]`` + + max_seq_len : int + max sequence length for kv_cache + + bmm1_scale : float + fused scale for bmm1 input. + + bmm2_scale : float + fused scale for bmm2 input. + + window_left : int = -1 + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. + + out : Optional[torch.Tensor] = None + output tensor, if not provided, will be allocated with ``query.dtype``. + + sinks : Optional[torch.Tensor] = None + additional value per head in the denominator of the softmax. + + kv_layout : str + The layout of the kv cache. Can be either ``NHD`` or ``HND``. Defaults to ``NHD``. + + enable_pdl : bool + Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization + Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode. + + Returns + ------- + out : torch.Tensor + output torch.Tensor. 
+ """ + enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl + + assert q_len_per_req == 1, "xqa not support speculative decoding yet" + + if isinstance(kv_cache, tuple): + k_cache, v_cache = kv_cache else: - raise ValueError(f"Invalid out_dtype: {out_dtype}") + if kv_cache.shape[1] == 1: + k_cache, v_cache = kv_cache, kv_cache + else: + assert kv_cache.shape[1] == 2, ( + "When kv_cache is a single tensor, the second dimension must be 1 or 2" + ) + # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...]) + # it doesn't change underlying storage + k_cache, v_cache = kv_cache.unbind(dim=1) + + sm_count = get_device_sm_count(query.device) bmm1_scale = ( bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale @@ -2222,35 +2419,58 @@ def trtllm_batch_decode_with_kv_cache( bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale ) - run_func( - out, - out_scale_factor, - query.view( - query.size(0) // q_len_per_req, q_len_per_req, query.size(1), query.size(2) - ), + # Extract shape parameters based on layout + if kv_layout == "NHD": + # NHD: [num_pages, page_size, num_kv_heads, head_dim] + page_size = k_cache.shape[1] + num_kv_heads = k_cache.shape[2] + head_dim = k_cache.shape[3] + else: # HND + # HND: [num_pages, num_kv_heads, page_size, head_dim] + num_kv_heads = k_cache.shape[1] + page_size = k_cache.shape[2] + head_dim = k_cache.shape[3] + + workspace_u8 = workspace_buffer.view(torch.uint8) + semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore + scratch = workspace_u8[8 * 1024 * 1024 :] + kv_scale_value = bmm2_scale + q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5) + + query_new = query.unsqueeze(1).contiguous() + seq_lens_new = seq_lens.unsqueeze(1).contiguous() + sinks_new = ( + sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None + ) + + # Ensure 4D output for xqa + if out is None: + out = torch.empty_like(query) + out_4d = out.unsqueeze(1) + + xqa( + query_new, k_cache, v_cache, - workspace_buffer, block_tables, - seq_lens, - max_seq_len, - bmm1_scale, - bmm2_scale, - o_sf_scale or -1.0, - o_sf_vec_size or -1, - o_sf_start_index, - window_left, - sm_count, - enable_pdl, - workspace_buffer.numel() * workspace_buffer.element_size(), - sinks, + seq_lens_new, + out_4d, + scratch, + semaphore, + num_kv_heads, + page_size, + sinks=sinks_new, + q_scale=q_scale_value, + kv_scale=torch.tensor( + [kv_scale_value], dtype=torch.float32, device=query.device + ), + sliding_win_size=window_left + 1 if window_left >= 0 else 0, + kv_layout=kv_layout, + sm_count=sm_count, + enable_pdl=enable_pdl, ) - return ( - out - if out_dtype != "nvfp4" - else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) - ) + return out def _check_trtllm_gen_mla_shape( @@ -2410,6 +2630,7 @@ def trtllm_batch_decode_with_kv_cache_mla( workspace_buffer.numel() * workspace_buffer.element_size(), sinks, ) + return out diff --git a/flashinfer/jit/xqa.py b/flashinfer/jit/xqa.py index 86fa3f7895..5768236c73 100644 --- a/flashinfer/jit/xqa.py +++ b/flashinfer/jit/xqa.py @@ -25,8 +25,6 @@ xqa_nvcc_flags = [ "-DNDEBUG=1", - "-DUSE_PAGED_KV_CACHE=1", - "-DPAGED_KV_CACHE_LAYOUT=1", "-DBEAM_WIDTH=1", "-DUSE_INPUT_KV=0", "-DUSE_CUSTOM_BARRIER=1", @@ -105,7 +103,6 @@ def gen_xqa_module( + flag_sliding_window + flag_mla_wrapper, extra_ldflags=["-lcuda"], # Add CUDA Driver API library - extra_cflags=["-DPAGED_KV_CACHE_LAYOUT=1"], ) @@ -164,5 +161,4 @@ def gen_xqa_module_mla( 
+ flag_sliding_window + flag_mla_wrapper, extra_ldflags=["-lcuda"], # Add CUDA Driver API library - extra_cflags=["-DPAGED_KV_CACHE_LAYOUT=1"], ) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 7399bd4268..49abe60897 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1857,7 +1857,6 @@ def plan( self._block_tables = block_tables if self._backend == "trtllm-gen": - assert self._kv_layout == "HND" assert logits_soft_cap == 0.0 if self._block_tables is None: blocks_per_seq = [ @@ -2041,6 +2040,7 @@ def run( _check_cached_qkv_data_type( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) + stride_block = k_cache.stride(0) if self._kv_layout == "NHD": page_size = k_cache.shape[1] @@ -2088,6 +2088,12 @@ def run( out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out" ) + # Convert NHD layout to HND for trtllm-gen backend + if self._backend == "trtllm-gen" and self._kv_layout == "NHD": + # For NHD: [..., N, H, D] -> HND: [..., H, N, D] + k_cache = k_cache.transpose(-3, -2) + v_cache = v_cache.transpose(-3, -2) + if self._custom_mask_buf is not None: mask_mode = MaskMode.CUSTOM.value else: @@ -3329,6 +3335,7 @@ def trtllm_batch_context_with_kv_cache( out_dtype: Optional[Union[torch.dtype, str]] = None, o_sf_scale: Optional[float] = None, o_sf_vec_size: Optional[int] = None, + kv_layout: str = "HND", enable_pdl: Optional[bool] = None, sinks: Optional[List[torch.Tensor]] = None, ) -> Union[torch.Tensor, FP4Tensor]: @@ -3338,8 +3345,11 @@ def trtllm_batch_context_with_kv_cache( query : torch.Tensor query tensor with shape [num_tokens, num_heads, head_dim] kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] - If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim] + If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is "HND", + or [num_pages, 1 or 2, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is "NHD". + If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is "HND", + or [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is "NHD". + The first tensor is the key cache, the second tensor is the value cache. workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. workspace block_tables : torch.Tensor @@ -3371,6 +3381,11 @@ def trtllm_batch_context_with_kv_cache( scale for nvfp4 output tensor scale factor. o_sf_vec_size : Optional[int] = None vector size for nvfp4 output tensor scale factor. + enable_pdl : Optional[bool] = None + Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization + Defaults to ``None``, which means it will be enabled if the device supports PDL. + kv_layout : str = "HND" + Layout of kv-cache, can be "HND" or "NHD", default is "HND". sinks : Optional[List[torch.Tensor]] = None additional value per head in the denominator of the softmax. 
@@ -3396,6 +3411,12 @@ def trtllm_batch_context_with_kv_cache( # it doesn't change underlying storage k_cache, v_cache = kv_cache.unbind(dim=1) + # Convert NHD layout to HND if necessary (transpose only changes stride, not data) + if kv_layout == "NHD": + # For NHD: [..., N, H, D] -> HND: [..., H, N, D] + k_cache = k_cache.transpose(-3, -2) + v_cache = v_cache.transpose(-3, -2) + run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context sm_count = get_device_sm_count(query.device) diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index 30997c2d10..fba5045d74 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -26,6 +26,7 @@ register_custom_op, register_fake_op, get_compute_capability, + device_support_pdl, ) @@ -69,6 +70,7 @@ def xqa( kv_scale: torch.Tensor, semaphores: torch.Tensor, workspace_buffer: torch.Tensor, + enable_pdl: bool, ) -> None: module.xqa_wrapper( run_sm90_fp8_mha, @@ -88,6 +90,7 @@ def xqa( kv_scale, semaphores, workspace_buffer, + enable_pdl, ) @register_fake_op( @@ -134,7 +137,9 @@ def xqa( q_scale: float = 1.0, kv_scale: Optional[torch.Tensor] = None, sliding_win_size: int = 0, + kv_layout: str = "NHD", sm_count: Optional[int] = None, + enable_pdl: Optional[bool] = None, ) -> None: r"""Apply attention with paged KV cache using XQA kernel. Parameters ---------- @@ -144,11 +149,13 @@ def xqa( Data type should be torch.float16 or torch.bfloat16. Now only beam_width 1 is supported. k_cache: torch.Tensor - Paged K cache tensor with shape ``[total_num_cache_heads, head_dim]``. + Paged K cache tensor with shape ``[num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation. Should be the same data type as v_cache. v_cache: torch.Tensor - Paged V cache tensor with shape ``[total_num_cache_heads, head_dim]``. + Paged V cache tensor with shape ``[num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation. Should be the same data type as k_cache. page_table : torch.Tensor @@ -183,9 +190,15 @@ def xqa( If None, defaults to 1.0. sliding_win_size : int, default=0 Sliding window size for attention. If 0, no sliding window is used. + kv_layout : str, default="NHD" + The layout of the KV cache. Can be either ``NHD`` or ``HND``. sm_count : Optional[int], default=None Number of streaming multiprocessors to use. If None, will be inferred from the device. + enable_pdl : Optional[bool], default=None + Whether to enable Programmatic Dependent Launch (PDL). + If None, will be set to True if the hardware supports it.
+ Note ---- The function automatically infers several parameters from tensor shapes: @@ -204,6 +217,8 @@ def xqa( if kv_scale is None: kv_scale = torch.ones(1, dtype=torch.float32, device=q.device) + enable_pdl = enable_pdl if enable_pdl is not None else device_support_pdl(q.device) + # Infer parameters from tensors batch_size = q.shape[0] num_q_heads = q.shape[2] @@ -221,6 +236,12 @@ def xqa( assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype" + # Convert HND layout to NHD if necessary (transpose only changes stride, not data) + if kv_layout == "HND": + # For HND: [..., H, N, D] -> NHD: [..., N, H, D] + k_cache = k_cache.transpose(-3, -2) + v_cache = v_cache.transpose(-3, -2) + if ( k_cache.dtype == torch.float8_e4m3fn and get_compute_capability(torch.device(device="cuda"))[0] == 9 @@ -258,6 +279,7 @@ def xqa( kv_scale, semaphores, workspace_buffer, + enable_pdl, ) @@ -297,6 +319,7 @@ def xqa_mla( kv_scale: torch.Tensor, semaphores: torch.Tensor, workspace_buffer: torch.Tensor, + enable_pdl: bool, ) -> None: module.xqa_wrapper_mla( sm_count, @@ -312,6 +335,7 @@ def xqa_mla( kv_scale, semaphores, workspace_buffer, + enable_pdl, ) @register_fake_op( @@ -331,6 +355,7 @@ def _fake_xqa_mla( kv_scale: torch.Tensor, semaphores: torch.Tensor, workspace_buffer: torch.Tensor, + enable_pdl: bool, ) -> None: pass @@ -352,6 +377,7 @@ def xqa_mla( q_scale: float = 1.0, kv_scale: Optional[torch.Tensor] = None, sm_count: Optional[int] = None, + enable_pdl: Optional[bool] = None, ) -> None: r"""Apply attention with paged KV cache using XQA MLA (Multi-Head Latent Attention) kernel. Parameters @@ -393,6 +419,10 @@ def xqa_mla( sm_count : Optional[int], default=None Number of streaming multiprocessors to use. If None, will be inferred from the device. + enable_pdl : Optional[bool], default=None + Whether to enable PDL (Persistent Data Loader) optimization. + If None, will be set to True if hardware supports it. 
+ Note ---- The function automatically infers several parameters from tensor shapes: @@ -408,6 +438,8 @@ def xqa_mla( if kv_scale is None: kv_scale = torch.ones(1, dtype=torch.float32, device=q.device) + enable_pdl = enable_pdl if enable_pdl is not None else device_support_pdl(q.device) + # Infer parameters from tensors batch_size = q.shape[0] head_dim = q.shape[-1] @@ -446,4 +478,5 @@ def xqa_mla( kv_scale, semaphores, workspace_buffer, + enable_pdl, ) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index f14c57b1f1..dcb5e01a2a 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -91,7 +91,14 @@ def create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype): def create_kv_cache( - batch_size, seq_lens, page_size, num_kv_heads, head_dim, kv_dtype, ref_kv_dtype + batch_size, + seq_lens, + page_size, + num_kv_heads, + head_dim, + kv_dtype, + ref_kv_dtype, + kv_layout="HND", ): # Create separate K and V caches max_seq_len = torch.max(seq_lens).item() @@ -103,22 +110,43 @@ def create_kv_cache( "kv_dtype and ref_kv_dtype must be the same for non-fp8 kv_cache" ) - k_cache = torch.randn( - num_pages, - num_kv_heads, - page_size, - head_dim, - dtype=ref_kv_dtype_torch, - device=GPU_DEVICE, - ) - v_cache = torch.randn( - num_pages, - num_kv_heads, - page_size, - head_dim, - dtype=ref_kv_dtype_torch, - device=GPU_DEVICE, - ) + # Create cache with appropriate layout + if kv_layout == "HND": + # HND layout: [num_pages, num_kv_heads, page_size, head_dim] + k_cache = torch.randn( + num_pages, + num_kv_heads, + page_size, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + v_cache = torch.randn( + num_pages, + num_kv_heads, + page_size, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + else: # NHD layout + # NHD layout: [num_pages, page_size, num_kv_heads, head_dim] + k_cache = torch.randn( + num_pages, + page_size, + num_kv_heads, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + v_cache = torch.randn( + num_pages, + page_size, + num_kv_heads, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) # Convert K and V separately to fp8 if needed if kv_dtype == "fp8": @@ -173,6 +201,7 @@ def flatten_paged_kv( seq_lens: torch.Tensor, page_size: int, kv_last_page_len: torch.Tensor, + kv_layout: str = "HND", ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Build flat K/V and token-level indptr from paged KV cache and page table.""" device = ref_kv_cache.device @@ -192,11 +221,20 @@ def flatten_paged_kv( page_id = int(page_table_cpu[i, j].item()) k_page = ref_kv_cache[page_id, 0] v_page = ref_kv_cache[page_id, 1] - if j == pages_i - 1: - k_page = k_page[:, :last_len_i, :] - v_page = v_page[:, :last_len_i, :] - k_list.append(einops.rearrange(k_page, "h p d -> p h d")) - v_list.append(einops.rearrange(v_page, "h p d -> p h d")) + if kv_layout == "HND": + # HND layout: [num_kv_heads, page_size, head_dim] + if j == pages_i - 1: + k_page = k_page[:, :last_len_i, :] + v_page = v_page[:, :last_len_i, :] + k_list.append(einops.rearrange(k_page, "h p d -> p h d")) + v_list.append(einops.rearrange(v_page, "h p d -> p h d")) + else: # NHD layout + # NHD layout: [page_size, num_kv_heads, head_dim] + if j == pages_i - 1: + k_page = k_page[:last_len_i, :, :] + v_page = v_page[:last_len_i, :, :] + k_list.append(einops.rearrange(k_page, "p h d -> p h d")) + v_list.append(einops.rearrange(v_page, "p h d -> p h d")) k_flat = torch.cat(k_list, 
dim=0) v_flat = torch.cat(v_list, dim=0) kv_indptr_tokens = torch.cat( @@ -301,7 +339,7 @@ def unpack_compare_nvfp4( return output_unpacked, output_ref -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize( "batch_size,page_size,num_kv_heads,head_grp_size", [ @@ -374,6 +412,7 @@ def test_trtllm_batch_prefill( head_dim, kv_dtype, "bf16" if q_dtype == "fp8" else q_dtype, + kv_layout, ) page_table, all_page_ids, page_per_seq = create_page_table( batch_size, seq_lens, page_size @@ -428,6 +467,7 @@ def test_trtllm_batch_prefill( seq_lens.to(GPU_DEVICE), page_size, kv_last_page_len, + kv_layout, ) sink = torch.rand(num_qo_heads, device=GPU_DEVICE, dtype=torch.float32) * 5 output_ref = sink_attention_unified( @@ -463,6 +503,7 @@ def test_trtllm_batch_prefill( out_dtype=out_dtype, o_sf_scale=o_sf_scale, o_sf_vec_size=o_sf_vec_size, + kv_layout=kv_layout, enable_pdl=enable_pdl, sinks=(sink if enable_sink else None), ) @@ -527,7 +568,7 @@ def test_trtllm_batch_prefill( assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize( "batch_size,page_size,num_kv_heads,head_grp_size", [ @@ -578,6 +619,7 @@ def test_trtllm_batch_prefill_bs1( def _test_trtllm_batch_decode( + backend, kv_layout, batch_size, q_len_per_req, @@ -599,8 +641,25 @@ def _test_trtllm_batch_decode( Combinations of parameters are tested in test_trtllm_batch_decode() and test_trtllm_batch_decode_...() """ compute_capability = get_compute_capability(torch.device(device="cuda")) - if compute_capability[0] != 10: - pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") + + # Check GPU architecture requirements for different backends + if backend == "trtllm-gen" and compute_capability[0] != 10: + pytest.skip("trtllm-gen backend requires SM100 and SM103 GPUs.") + if backend == "xqa" and compute_capability[0] < 9: + pytest.skip("xqa backend requires SM90+ GPUs.") + + # xqa backend doesn't support nvfp4 output + if backend == "xqa" and o_dtype == "nvfp4": + pytest.skip("xqa backend does not support nvfp4 output") + + if backend == "xqa" and q_dtype == "fp8": + pytest.skip("xqa backend only supports fp16 and bf16 query") + + # xqa backend doesn't support speculative decoding yet + if backend == "xqa" and q_len_per_req > 1: + pytest.skip( + "xqa backend does not support speculative decoding (q_len_per_req > 1) yet" + ) if o_dtype == "nvfp4" and q_len_per_req > 1: # todo(Yingyi): add support for nvfp4 with speculative decoding @@ -628,6 +687,7 @@ def _test_trtllm_batch_decode( head_dim, kv_dtype, "bf16" if q_dtype == "fp8" else q_dtype, + kv_layout, ) page_table, all_page_ids, page_per_seq = create_page_table( batch_size, seq_lens, page_size @@ -700,6 +760,7 @@ def _test_trtllm_batch_decode( seq_lens.to(GPU_DEVICE), page_size, kv_last_page_len, + kv_layout, ) sink = torch.rand(num_qo_heads, device=GPU_DEVICE, dtype=torch.float32) * 5 output_ref = sink_attention_unified( @@ -716,7 +777,7 @@ def _test_trtllm_batch_decode( kv_indptr=kv_indptr_tokens, ) - # Run trtllm-gen function call + # Run decode function call with specified backend output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( q.contiguous(), kv_cache, @@ -731,13 +792,16 @@ def _test_trtllm_batch_decode( out_dtype=out_dtype, o_sf_scale=o_sf_scale, o_sf_vec_size=o_sf_vec_size, - 
enable_pdl=enable_pdl, sinks=(sink if enable_sink else None), + kv_layout=kv_layout, + enable_pdl=enable_pdl, + backend=backend, q_len_per_req=q_len_per_req, ) - # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero - # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future - assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + if backend == "trtllm-gen": + # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero + # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future + assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() if o_dtype == "nvfp4": output, output_ref = unpack_compare_nvfp4( @@ -752,6 +816,10 @@ def _test_trtllm_batch_decode( else: rtol, atol = 1e-2, 1e-2 + if backend == "xqa" and kv_dtype == "fp8": + atol = 1e-1 + rtol = 1e-1 + # convert to float32 for fp8 is not supported by assert_close # relax rtol and atol for speculative decoding test if q_len_per_req > 1: @@ -771,7 +839,10 @@ def _test_trtllm_batch_decode( max_mismatched_elements=max_mismatched_elements, ) - if o_dtype != "nvfp4": # wrapper api does not support fp4 output yet. + # Only test wrapper with trtllm-gen backend + if ( + o_dtype != "nvfp4" and backend == "trtllm-gen" + ): # wrapper api does not support fp4 output yet. # test wrapper with trtllm-gen backend wrapper_trtllm_gen = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="trtllm-gen" @@ -823,7 +894,8 @@ def _test_trtllm_batch_decode( assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize( "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", [ @@ -865,6 +937,7 @@ def _test_trtllm_batch_decode( @pytest.mark.parametrize("max_in_kv_len", [110]) @pytest.mark.parametrize("head_dim", [128]) def test_trtllm_batch_decode( + backend, kv_layout, batch_size, q_len_per_req, @@ -882,6 +955,7 @@ def test_trtllm_batch_decode( ): # General set of tests for trtllm-gen decode _test_trtllm_batch_decode( + backend, kv_layout, batch_size, q_len_per_req, @@ -936,6 +1010,7 @@ def test_trtllm_batch_decode_bs1( # Small number of test cases for batch size 1 pytest.xfail("trtllm-gen decode gets incorrect output with bs1") _test_trtllm_batch_decode( + "trtllm-gen", kv_layout, batch_size, q_len_per_req, @@ -1001,6 +1076,7 @@ def test_trtllm_batch_decode_head_dim_256( # Small number of test cases for head_dim = 256 pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256") _test_trtllm_batch_decode( + "trtllm-gen", kv_layout, batch_size, q_len_per_req, @@ -1059,6 +1135,7 @@ def test_trtllm_batch_decode_long_sequence_length( # Small number of test cases for long sequence length pytest.xfail("trtllm-gen decode gets incorrect output with Long sequence length") _test_trtllm_batch_decode( + "trtllm-gen", kv_layout, batch_size, q_len_per_req, diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index a830134fc0..4e81f72bd5 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -40,22 +40,24 @@ def __init__( nb_heads: int, idx_head: int, tokens_per_page: int = 32, + kv_layout: str = "NHD", ): self.pool = pool self.page_indices = page_indices self.nb_heads = 
nb_heads self.idx_head = idx_head self.tokens_per_page = tokens_per_page + self.kv_layout = kv_layout def __getitem__(self, i: int) -> torch.Tensor: page_idx = self.page_indices[i // self.tokens_per_page].to(torch.int32) - # VLLM layout (PAGED_KV_CACHE_LAYOUT=1): [page_idx][token_in_page][nb_heads][head_dim] - idx_head = ( - page_idx * self.tokens_per_page * self.nb_heads - + (i % self.tokens_per_page) * self.nb_heads - + self.idx_head - ) - return self.pool[idx_head] + token_in_page = i % self.tokens_per_page + if self.kv_layout == "NHD": + # NHD layout: [page_idx, token_in_page, idx_head, :] + return self.pool[page_idx, token_in_page, self.idx_head, :] + else: # HND + # HND layout: [page_idx, idx_head, token_in_page, :] + return self.pool[page_idx, self.idx_head, token_in_page, :] def ref_attention( @@ -167,6 +169,7 @@ def ref_attention( get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12], reason="XQA is only supported on SM90, SM100, SM120 GPUs", ) +@pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("use_sliding_window", [True, False]) @pytest.mark.parametrize("input_type", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("fp8_kv_cache", [True, False]) @@ -177,6 +180,7 @@ def ref_attention( @pytest.mark.parametrize("tokens_per_page", [16, 64]) @pytest.mark.parametrize("valid_elems_per_head", [32, 128]) @pytest.mark.parametrize("head_grp_size", [8, 16]) +@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) def test_xqa( batch_size, nb_k_heads, @@ -188,6 +192,8 @@ def test_xqa( head_grp_size, use_attention_sinks, use_sliding_window, + enable_pdl, + kv_layout, ): set_random_seed(42) @@ -227,24 +233,48 @@ def test_xqa( max_seq_len = round_up(seq_len, tokens_per_page) nb_pages_per_seq = div_up(max_seq_len, tokens_per_page) - # Layout 1: K and V share page indices - # Total cache heads = nb_k_heads * max_seq_len * batch_size - total_nb_cache_heads = nb_k_heads * max_seq_len * batch_size + # Total number of pages needed for all sequences + total_num_pages = nb_pages_per_seq * batch_size + + # Create cache with specified layout + if kv_layout == "NHD": + # NHD layout: [num_pages, page_size, num_kv_heads, head_dim] + cache_k_heads = torch.zeros( + total_num_pages, + tokens_per_page, + nb_k_heads, + valid_elems_per_head, + dtype=input_type, + device="cuda", + ) + cache_v_heads = torch.zeros( + total_num_pages, + tokens_per_page, + nb_k_heads, + valid_elems_per_head, + dtype=input_type, + device="cuda", + ) + else: # HND layout + # HND layout: [num_pages, num_kv_heads, page_size, head_dim] + cache_k_heads = torch.zeros( + total_num_pages, + nb_k_heads, + tokens_per_page, + valid_elems_per_head, + dtype=input_type, + device="cuda", + ) + cache_v_heads = torch.zeros( + total_num_pages, + nb_k_heads, + tokens_per_page, + valid_elems_per_head, + dtype=input_type, + device="cuda", + ) - cache_k_heads = torch.zeros( - total_nb_cache_heads, - valid_elems_per_head, - dtype=input_type, - device="cuda", - ) cache_k_heads.normal_(0, 1) - - cache_v_heads = torch.zeros( - total_nb_cache_heads, - valid_elems_per_head, - dtype=input_type, - device="cuda", - ) cache_v_heads.normal_(0, 1) if fp8_kv_cache: @@ -279,18 +309,19 @@ def cache_head_at( beam_width, nb_k_heads, tokens_per_page, + kv_layout, ): - # Layout 1: K and V share page indices + # K and V share page indices page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32) + token_in_page = pos % tokens_per_page - # VLLM layout: [page_idx][token_in_page][nb_heads][head_dim] - idx_head = ( 
- page_idx * tokens_per_page * nb_k_heads - + (pos % tokens_per_page) * nb_k_heads - + idx_kv_head - ) - - return cache_k_heads[idx_head] if is_k else cache_v_heads[idx_head] + cache = cache_k_heads if is_k else cache_v_heads + if kv_layout == "NHD": + # NHD layout: [page_idx, token_in_page, idx_kv_head, :] + return cache[page_idx, token_in_page, idx_kv_head, :] + else: # HND + # HND layout: [page_idx, idx_kv_head, token_in_page, :] + return cache[page_idx, idx_kv_head, token_in_page, :] for batch in range(batch_size): for kv in range(2): @@ -307,6 +338,7 @@ def cache_head_at( beam_width, nb_k_heads, tokens_per_page, + kv_layout, ) cache_head.fill_(0.0) @@ -340,19 +372,22 @@ def cache_head_at( q_scale=q_scale, kv_scale=kv_cache_scale, sliding_win_size=sliding_win_size, + kv_layout=kv_layout, sm_count=sm_count, + enable_pdl=enable_pdl, ) for req in range(batch_size): for b in range(beam_width): for idx_k_head in range(nb_k_heads): - # Layout 1: K and V use separate pools but share page indices + # K and V use separate pools but share page indices k_cache_seq = CacheSeq( pool=cache_k_heads, page_indices=page_list_arg[req], nb_heads=nb_k_heads, idx_head=idx_k_head, tokens_per_page=tokens_per_page, + kv_layout=kv_layout, ) v_cache_seq = CacheSeq( pool=cache_v_heads, @@ -360,6 +395,7 @@ def cache_head_at( nb_heads=nb_k_heads, idx_head=idx_k_head, tokens_per_page=tokens_per_page, + kv_layout=kv_layout, ) ref_output = ref_attention( @@ -407,6 +443,7 @@ def cache_head_at( get_compute_capability(torch.device(device="cuda"))[0] not in [12], reason="XQA mla is only supported on SM120 GPUs", ) +@pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("seq_len", [2, 15, 256, 514, 2048]) @pytest.mark.parametrize("batch_size", [1, 2]) @pytest.mark.parametrize("tokens_per_page", [32, 64]) @@ -414,6 +451,7 @@ def test_xqa_mla( batch_size, seq_len, tokens_per_page, + enable_pdl, ): set_random_seed(42) @@ -446,12 +484,14 @@ def test_xqa_mla( max_seq_len = round_up(seq_len, tokens_per_page) nb_pages_per_seq = div_up(max_seq_len, tokens_per_page) - # Layout 1: K and V share page indices - # Total cache heads = nb_k_heads * max_seq_len * batch_size - total_nb_cache_heads = nb_k_heads * max_seq_len * batch_size + # Total number of pages needed for all sequences + total_num_pages = nb_pages_per_seq * batch_size + # NHD layout: [num_pages, page_size, num_kv_heads, head_dim] cache_k_heads = torch.zeros( - total_nb_cache_heads, + total_num_pages, + tokens_per_page, + nb_k_heads, valid_elems_per_head_qk, # K dimension is 576 dtype=torch.float32, device="cuda", @@ -459,7 +499,9 @@ def test_xqa_mla( cache_k_heads.normal_(0, 1) cache_v_heads = torch.zeros( - total_nb_cache_heads, + total_num_pages, + tokens_per_page, + nb_k_heads, valid_elems_per_head_qk, # V storage is 576 (but only 512 used) dtype=torch.float32, device="cuda", @@ -497,17 +539,13 @@ def cache_head_at( nb_k_heads, tokens_per_page, ): - # Layout 1: K and V share page indices + # K and V share page indices page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32) + token_in_page = pos % tokens_per_page - # VLLM layout: [page_idx][token_in_page][nb_heads][head_dim] - idx_head = ( - page_idx * tokens_per_page * nb_k_heads - + (pos % tokens_per_page) * nb_k_heads - + idx_kv_head - ) - - return cache_k_heads[idx_head] if is_k else cache_v_heads[idx_head] + # NHD layout: [page_idx, token_in_page, idx_kv_head, :] + cache = cache_k_heads if is_k else cache_v_heads + return cache[page_idx, token_in_page, idx_kv_head, :] for batch 
in range(batch_size): for kv in range(2): @@ -555,12 +593,13 @@ def cache_head_at( q_scale=q_scale, kv_scale=kv_cache_scale, sm_count=sm_count, + enable_pdl=enable_pdl, ) for req in range(batch_size): for b in range(beam_width): for idx_k_head in range(nb_k_heads): - # Layout 1: K and V use separate pools but share page indices + # K and V use separate pools but share page indices k_cache_seq = CacheSeq( pool=cache_k_heads, page_indices=page_list_arg[req], diff --git a/tests/attention/test_xqa_batch_decode.py b/tests/attention/test_xqa_batch_decode.py new file mode 100644 index 0000000000..fbeac45354 --- /dev/null +++ b/tests/attention/test_xqa_batch_decode.py @@ -0,0 +1,457 @@ +import pytest +import torch +from tests.test_helpers.sink_attention_reference import sink_attention_unified + +import flashinfer +from flashinfer.utils import get_compute_capability + +DTYPE_MAP = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp8": torch.float8_e4m3fn, +} + +GPU_DEVICE = "cuda:0" + +global_workspace_buffer = None # can be empty initialized +global_xqa_workspace_buffer = None # must be zero initialized +workspace_size = 256 * 1024 * 1024 + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len): + q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32) + in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int) + in_kv_lens[-1] = max_in_kv_len + seq_lens = q_lens + in_kv_lens + return q_lens, in_kv_lens, seq_lens + + +def generate_cumsum_lens(lens): + return torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=GPU_DEVICE), + torch.cumsum(lens.to(GPU_DEVICE), dim=0, dtype=torch.int32), + ] + ) + + +def create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype): + q = torch.randn( + torch.sum(q_lens).item(), + num_qo_heads, + head_dim, + dtype=torch.bfloat16 if q_dtype == "fp8" else DTYPE_MAP[q_dtype], + device=GPU_DEVICE, + ) + if q_dtype == "fp8": + q, q_scale = to_float8(q) + # Reference implementation have functional issue or low precision with fp8, use bfloat16 and fake-quantization instead. 
+ ref_q = q.bfloat16() * q_scale + else: + q_scale = 1.0 + ref_q = q + + return q, q_scale, ref_q + + +def create_kv_cache( + batch_size, + seq_lens, + page_size, + num_kv_heads, + head_dim, + kv_dtype, + ref_kv_dtype, + kv_layout="NHD", +): + # Create separate K and V caches with specified layout (NHD or HND) + max_seq_len = torch.max(seq_lens).item() + num_tokens = max_seq_len * batch_size + num_pages = (num_tokens + page_size - 1) // page_size + ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype] + if kv_dtype != "fp8": + assert kv_dtype == ref_kv_dtype, ( + "kv_dtype and ref_kv_dtype must be the same for non-fp8 kv_cache" + ) + + # Create cache with specified layout + if kv_layout == "NHD": + # NHD layout: [num_pages, page_size, num_kv_heads, head_dim] + k_cache = torch.randn( + num_pages, + page_size, + num_kv_heads, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + v_cache = torch.randn( + num_pages, + page_size, + num_kv_heads, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + else: # HND layout + # HND layout: [num_pages, num_kv_heads, page_size, head_dim] + k_cache = torch.randn( + num_pages, + num_kv_heads, + page_size, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + v_cache = torch.randn( + num_pages, + num_kv_heads, + page_size, + head_dim, + dtype=ref_kv_dtype_torch, + device=GPU_DEVICE, + ) + + # Convert K and V separately to fp8 if needed + if kv_dtype == "fp8": + k_cache, k_scale = to_float8(k_cache / 4.0) + v_cache, v_scale = to_float8(v_cache / 4.0) + # use high precision and fake-quantization for reference to avoid precision/functional issue + ref_kv_cache = torch.stack( + [ + k_cache.to(ref_kv_dtype_torch) * k_scale, + v_cache.to(ref_kv_dtype_torch) * v_scale, + ], + dim=1, + ) + else: + k_scale = v_scale = 1.0 + ref_kv_cache = torch.stack([k_cache, v_cache], dim=1) + # Combine K and V into interleaved format for the API + kv_cache = torch.stack([k_cache, v_cache], dim=1) + + return kv_cache, k_scale, v_scale, ref_kv_cache + + +def create_page_table(batch_size, seq_lens, page_size): + page_per_seq = (seq_lens + page_size - 1) // page_size + max_num_pages_per_seq = torch.max(page_per_seq).item() + + # Generate random but unique page IDs for all sequences + total_pages_needed = torch.sum(page_per_seq).item() + all_page_ids = torch.randperm( + total_pages_needed, dtype=torch.int32, device=GPU_DEVICE + ) + + # Generate unique page IDs for all sequences + page_tables = torch.zeros( + (batch_size, max_num_pages_per_seq), dtype=torch.int32, device=GPU_DEVICE + ) + + # Populate page tables and track page assignments + page_id = 0 + for i in range(batch_size): + num_pages_needed = page_per_seq[i] + page_tables[i, :num_pages_needed] = all_page_ids[ + page_id : page_id + num_pages_needed + ] + page_id += num_pages_needed + return page_tables, all_page_ids, page_per_seq + + +def flatten_paged_kv( + ref_kv_cache: torch.Tensor, + page_table: torch.Tensor, + seq_lens: torch.Tensor, + page_size: int, + kv_last_page_len: torch.Tensor, + kv_layout: str = "NHD", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build flat K/V and token-level indptr from paged KV cache and page table. + + Supports both NHD and HND layouts. 
+ """ + device = ref_kv_cache.device + batch_size = int(page_table.shape[0]) + + # Move loop-control tensors to CPU to avoid GPU sync in loops + page_table_cpu = page_table.cpu() + seq_lens_cpu = seq_lens.cpu() + kv_last_page_len_cpu = kv_last_page_len.cpu() + page_per_seq = (seq_lens_cpu + page_size - 1) // page_size + k_list = [] + v_list = [] + for i in range(batch_size): + pages_i = int(page_per_seq[i].item()) + last_len_i = int(kv_last_page_len_cpu[i].item()) + for j in range(pages_i): + page_id = int(page_table_cpu[i, j].item()) + if kv_layout == "NHD": + # NHD: [page_id, 0/1, page_size, num_heads, head_dim] + k_page = ref_kv_cache[page_id, 0] # [page_size, num_heads, head_dim] + v_page = ref_kv_cache[page_id, 1] + if j == pages_i - 1: + k_page = k_page[:last_len_i, :, :] + v_page = v_page[:last_len_i, :, :] + else: # HND + # HND: [page_id, 0/1, num_heads, page_size, head_dim] + k_page = ref_kv_cache[page_id, 0] # [num_heads, page_size, head_dim] + v_page = ref_kv_cache[page_id, 1] + if j == pages_i - 1: + k_page = k_page[:, :last_len_i, :] + v_page = v_page[:, :last_len_i, :] + # Transpose to NHD: [num_heads, page_size, head_dim] -> [page_size, num_heads, head_dim] + k_page = k_page.transpose(0, 1) + v_page = v_page.transpose(0, 1) + k_list.append(k_page) + v_list.append(v_page) + k_flat = torch.cat(k_list, dim=0) + v_flat = torch.cat(v_list, dim=0) + kv_indptr_tokens = torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(seq_lens, dim=0, dtype=torch.int32), + ] + ) + return k_flat, v_flat, kv_indptr_tokens + + +def create_workspace_buffers(device): + # Lazily initialize and reuse global workspace buffers + global global_workspace_buffer, global_xqa_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + workspace_size, dtype=torch.int8, device=device + ) + if global_xqa_workspace_buffer is None: + global_xqa_workspace_buffer = torch.zeros( + workspace_size, dtype=torch.int8, device=device + ) + return global_xqa_workspace_buffer, global_workspace_buffer + + +def create_output(q, o_dtype): + """Create output tensor for the given query and output dtype.""" + if o_dtype == "fp8": + o_scale = torch.rand(1).item() * 0.5 + 0.5 # Scale range: 0.5 ~ 1.0 + out = torch.empty(q.shape, dtype=torch.float8_e4m3fn, device=q.device) + else: + o_scale = 1.0 + out = torch.empty(q.shape, dtype=DTYPE_MAP[o_dtype], device=q.device) + + return out, o_scale + + +def get_last_page_len(seq_lens, page_size): + """Get the valid token count in the last page for each sequence""" + last_page_len = seq_lens % page_size + # If the sequence length is a multiple of page_size, the last page is full + last_page_len = torch.where(last_page_len == 0, page_size, last_page_len) + return last_page_len + + +@pytest.mark.skipif( + get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12], + reason="XQA is only supported on SM90, SM100, SM120 GPUs", +) +@pytest.mark.parametrize( + "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + [ + (4, 1, 16, 2, 1), + (4, 1, 32, 2, 5), + (128, 1, 64, 2, 6), + (256, 1, 64, 4, 8), + ], +) +@pytest.mark.parametrize("window_left", [-1, 127]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("bf16", "bf16", "bf16"), + ("fp16", "fp16", "fp16"), + ("bf16", "fp8", "bf16"), + ("fp16", "fp8", "fp16"), + ], +) +@pytest.mark.parametrize("enable_pdl", [True, False, None]) +@pytest.mark.parametrize("enable_sink", [True, False]) +@pytest.mark.parametrize("max_in_kv_len", [110]) 
+@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +def test_xqa_batch_decode( + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_in_kv_len, + kv_layout, +): + """Test xqa_batch_decode_with_kv_cache function. + + This test supports both NHD and HND layouts. + """ + if q_len_per_req > 1: + pytest.skip("xqa does not support speculative decoding yet") + + # Set up test parameters + torch.manual_seed(0) + head_dim = 128 + + # Generate random sequence lengths + num_qo_heads = num_kv_heads * head_grp_size + q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode( + batch_size, q_len_per_req, max_in_kv_len + ) + + # Create query tensor and related data + q, q_scale, ref_q = create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype) + q_indptr = generate_cumsum_lens(q_lens) + + # Create KV cache and related data + kv_cache, k_scale, v_scale, ref_kv_cache = create_kv_cache( + batch_size, + seq_lens, + page_size, + num_kv_heads, + head_dim, + kv_dtype, + "bf16" if q_dtype == "fp8" else q_dtype, + kv_layout, + ) + page_table, all_page_ids, page_per_seq = create_page_table( + batch_size, seq_lens, page_size + ) + kv_indptr = generate_cumsum_lens(page_per_seq) + kv_last_page_len = get_last_page_len(seq_lens, page_size) + + workspace_buffer, workspace_buffer_ref = create_workspace_buffers(GPU_DEVICE) + + # Create output tensor and related data + out, o_scale = create_output(q, o_dtype) + + sm_scale = float(1.0 / (head_dim**0.5)) + + # Build reference output + plan_params = { + "indptr": kv_indptr, + "indices": all_page_ids, + "last_page_len": kv_last_page_len.to(GPU_DEVICE), + "num_qo_heads": num_qo_heads, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "page_size": page_size, + "pos_encoding_mode": "NONE", + "kv_data_type": ref_kv_cache.dtype, + "q_data_type": ref_q.dtype, + "window_left": window_left, + } + if not enable_sink: + if q_len_per_req == 1: + wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer_ref, kv_layout, use_tensor_cores=True + ) + wrapper_ref.plan(**plan_params) + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + else: + # speculative decoding test + wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer_ref, kv_layout + ) + plan_params_prefill = plan_params.copy() + plan_params_prefill.update( + { + "qo_indptr": q_indptr, + "paged_kv_indptr": plan_params_prefill.pop("indptr"), + "paged_kv_indices": plan_params_prefill.pop("indices"), + "paged_kv_last_page_len": plan_params_prefill.pop("last_page_len"), + "head_dim_qk": plan_params_prefill.pop("head_dim"), + "causal": True, + "logits_soft_cap": 0.0, + } + ) + wrapper_ref.plan(**plan_params_prefill) + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + else: + # Construct flat K/V via helper + k_flat, v_flat, kv_indptr_tokens = flatten_paged_kv( + ref_kv_cache, + page_table, + seq_lens.to(GPU_DEVICE), + page_size, + kv_last_page_len, + kv_layout, + ) + sink = torch.rand(num_qo_heads, device=GPU_DEVICE, dtype=torch.float32) * 5 + output_ref = sink_attention_unified( + ref_q, + k_flat, + v_flat, + sink, + window_left, + True, + sm_scale, + mode="varlen", + batch_size=batch_size, + qo_indptr=q_indptr, + kv_indptr=kv_indptr_tokens, + ) + + # Run xqa_batch_decode_with_kv_cache function + output = flashinfer.decode.xqa_batch_decode_with_kv_cache( + q.contiguous(), + kv_cache, + workspace_buffer, + page_table, + seq_lens.to(GPU_DEVICE), + 
torch.max(seq_lens).item(), + q_scale * k_scale * sm_scale, # bmm1_scale + v_scale / o_scale, # bmm2_scale + window_left, # window_left + out=out, + enable_pdl=enable_pdl, + sinks=(sink if enable_sink else None), + kv_layout=kv_layout, + q_len_per_req=q_len_per_req, + ) + + # Verification + torch.testing.assert_close( + output, + output_ref, + rtol=1e-1 if kv_dtype == "fp8" else 1e-2, + atol=1e-1 if kv_dtype == "fp8" else 1e-2, + ) + + +if __name__ == "__main__": + # Run a simple test case + test_xqa_batch_decode( + batch_size=4, + q_len_per_req=1, + page_size=16, + num_kv_heads=2, + head_grp_size=1, + window_left=-1, + q_dtype="bf16", + kv_dtype="bf16", + o_dtype="bf16", + enable_pdl=True, + enable_sink=True, + max_in_kv_len=110, + kv_layout="NHD", + ) From da01b1bd8f9f22aec8c0eea189ad54860b034947 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Sat, 1 Nov 2025 23:31:36 -0700 Subject: [PATCH 014/130] test: Enable xfailed trtllm decode long seqlen tests and update microbenchmark (#2018) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description [tests/attention/test_trtllm_gen_attention.py](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/tests/attention/test_trtllm_gen_attention.py#L1021-L1076) was failing and therefore marked xfail. PR #2002 fixed the underlying root cause. Current PR thus removed the `xfail` marker so that these long seqlen cases could be fixed moving forward. Additionally, PR #2002 revealed a bug in the microbenchmark script where [trtllm_batch_decode_with_kv_cache](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/flashinfer/decode.py#L2082-L2083) explicitly requires the workspace to be zeroed before first use: ``` workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. workspace ``` while the microbenchmark code does not zero out, causing undefined behavior such as IMAs that depend on the ordering of backends tested. Current PR fixes the issue by explicitly calling `workspace_buffer.zero_()` between testing different backends. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Improved stability of performance benchmarks by properly resetting workspace buffer between backend invocations. * **Tests** * Enabled previously skipped test for long sequence length handling. 
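For reference, a minimal sketch of the pattern this PR applies in the benchmark loop. The names below (`backends`, `workspace_buffer`, `run_backend_wrapper`) are illustrative stand-ins for the real setup in `benchmarks/routines/attention.py`, not its exact code:

```python
import torch

# Illustrative stand-ins; the real benchmark builds these from its CLI arguments.
backends = ["fa2", "trtllm-gen"]
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8, device="cuda")

def run_backend_wrapper(backend):
    # Placeholder for the per-backend dispatch in benchmarks/routines/attention.py.
    return torch.randn(8, device="cuda")

outputs = {}
for cur_backend in backends:
    # Zero the shared workspace before every backend run: trtllm-gen decode
    # requires a zero-initialized workspace on first use, and clearing it also
    # keeps one backend's leftover state from affecting the next.
    workspace_buffer.zero_()
    outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
```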
--- benchmarks/routines/attention.py | 8 ++++++++ tests/attention/test_trtllm_gen_attention.py | 1 - 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index bfebc37d4d..9dd2442eed 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -508,6 +508,8 @@ def run_backend_wrapper(backend): has_reference_output = False # Iterate over each backend: for cur_backend in backends: + # Clear workspace buffer to prevent unexpected interactions between backends. + workspace_buffer.zero_() if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone() if cur_backend == "fa2": @@ -975,6 +977,8 @@ def run_backend_wrapper(backend): has_reference_output = False # Iterate over each backend: for cur_backend in backends: + # Clear workspace buffer to prevent unexpected interactions between backends. + workspace_buffer.zero_() if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone() if cur_backend == "fa2": @@ -1427,6 +1431,8 @@ def run_backend_wrapper(backend): has_reference_output = False # Iterate over each backend: for cur_backend in backends: + # Clear workspace buffer to prevent unexpected interactions between backends. + workspace_buffer.zero_() if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone() if cur_backend == "fa2": @@ -1822,6 +1828,8 @@ def run_backend_wrapper(backend): has_reference_output = False # Iterate over each backend: for cur_backend in backends: + # Clear workspace buffer to prevent unexpected interactions between backends. + workspace_buffer.zero_() if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone() if cur_backend == "fa2": diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index dcb5e01a2a..4d1fe2891c 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -1133,7 +1133,6 @@ def test_trtllm_batch_decode_long_sequence_length( head_dim, ): # Small number of test cases for long sequence length - pytest.xfail("trtllm-gen decode gets incorrect output with Long sequence length") _test_trtllm_batch_decode( "trtllm-gen", kv_layout, From 1e75bff99c175f5344ca079d40950106db157aee Mon Sep 17 00:00:00 2001 From: Maximilien Breughe <50598321+nvmbreughe@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:49:04 -0600 Subject: [PATCH 015/130] Updated decorator to support unspecified default (#2026) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Updated decorator to support unspecified default. This was causing issues when calling mm_fp4 without backend specified. Also added SM 110 as a supported backend on the cutlass backend (mm_fp4) ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. 
- [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * FP4 Cutlass GEMM now supports the SM110 GPU compute capability. * **Bug Fixes** * Kernels called without an explicit backend now consistently use the default backend. * **Tests** * Added a unit test to verify default backend selection and correct results when backend is omitted. --- flashinfer/gemm.py | 2 +- flashinfer/utils.py | 35 +++++++++++++++++++++++---------- tests/utils/test_decorators.py | 36 ++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 11 deletions(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index b561a67862..9f00cc6e25 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -1834,7 +1834,7 @@ def _trtllm_gemm_fp4_requirement( return True -@supported_compute_capability([100, 103, 120, 121]) +@supported_compute_capability([100, 103, 110, 120, 121]) def _cutlass_gemm_fp4_requirement( a: torch.Tensor, b: torch.Tensor, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 936d08380c..eb42e1291e 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -23,6 +23,7 @@ import torch.version from torch.torch_version import TorchVersion from torch.torch_version import __version__ as torch_version +import inspect from .jit.spdlog import gen_spdlog_module @@ -950,6 +951,9 @@ def backend_requirement( """ def decorator(func): + # Get the function signature once for reuse + sig = inspect.signature(func) + def is_backend_supported(backend, cc=None): # Is this backend present? if backend not in backend_checks: @@ -971,7 +975,9 @@ def is_compute_capability_supported(cc): for checker in backend_checks.values() ) - def is_problem_size_supported(*args, **kwargs): + # @note: this function does not automatically apply defaults to the arguments. + def _is_problem_size_supported(*args, **kwargs): + # At this point, kwargs should have defaults applied, so backend should be present backend = kwargs.get("backend") if backend not in backend_checks: raise BackendSupportedError( @@ -983,26 +989,34 @@ def is_problem_size_supported(*args, **kwargs): else: return req_checker(*args, **kwargs) + # @brief: Wrapper function that calls the orignal, decorated function, after applying a number of checks. + # @note that here we manually apply defaults to the arguments in the wrapper function when doing validation. @functools.wraps(func) def wrapper(*args, **kwargs): - backend = kwargs.get("backend") # skip_check is an optional argument that the decorator adds to any API function. # It prevents the performance overhead of checking. skip_check = kwargs.pop("skip_check", False) if not skip_check: + # Apply defaults from the function signature for validation + # This ensures that all parameters (including backend) have their default values + # if not explicitly provided by the caller + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + # Convert to kwargs for validation functions + kwargs_with_defaults = dict(bound_args.arguments) + + backend = kwargs_with_defaults.get("backend") + capability = None # Find the first tensor argument. # Assume all tensors are on the same device/capability. # We could consider check all tensors at a performance cost. 
tensor_arg = None - for arg in args: - if isinstance(arg, torch.Tensor): - tensor_arg = arg - if tensor_arg is None: - for value in kwargs.values(): - if isinstance(value, torch.Tensor): - tensor_arg = value + for value in kwargs_with_defaults.values(): + if isinstance(value, torch.Tensor): + tensor_arg = value + break if tensor_arg is not None: # Get compute capability from the first tensor @@ -1015,10 +1029,11 @@ def wrapper(*args, **kwargs): raise BackendSupportedError( f"{func.__name__} does not support backend '{backend}'{extra}" ) - if not is_problem_size_supported(*args, **kwargs): + if not _is_problem_size_supported(**kwargs_with_defaults): raise ValueError( f"Problem size is not supported for {func.__name__}" ) + return func(*args, **kwargs) wrapper.is_backend_supported = is_backend_supported diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index e0520b1d43..e0528cfd60 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -210,3 +210,39 @@ def my_documented_function(x, backend="backend"): # Verify that added methods still exist assert hasattr(my_documented_function, "is_backend_supported") assert hasattr(my_documented_function, "is_compute_capability_supported") + + +def test_backend_default_parameter(): + """Test that backend_requirement correctly uses default backend parameter when not specified.""" + if not torch.cuda.is_available(): + pytest.skip("Skipping CUDA tests (no GPU available)") + + # Get actual device capability + x = torch.randn(1, 1, device="cuda") + major, minor = torch.cuda.get_device_capability(x.device) + actual_capability = major * 10 + minor + + @supported_compute_capability([80, 86, 89, 90, actual_capability]) + def _cutlass_check(x, backend): + return x.shape[0] > 0 + + @supported_compute_capability([75, 80, 86, 89, 90, actual_capability]) + def _cudnn_check(x, backend): + return x.shape[0] > 0 + + @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check}) + def my_kernel(x, backend="cudnn"): + return x * 2 + + x = torch.randn(10, 10, device="cuda") + + # Test that calling without backend argument uses the default "cudnn" + # This should work without raising an error + result = my_kernel(x) + assert result.shape == x.shape + assert torch.allclose(result, x * 2) + + # Test that explicitly passing a different backend also works + result2 = my_kernel(x, backend="cutlass") + assert result2.shape == x.shape + assert torch.allclose(result2, x * 2) From 2d68a6bb9860a59ca5a257d5ce527293d620af28 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Mon, 3 Nov 2025 18:43:38 -0800 Subject: [PATCH 016/130] release: Bump version for v0.5.1 release (#2031) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Update `version.txt` ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). 
## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Version updated to 0.5.1 --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 8f0916f768..4b9fcbec10 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.5.0 +0.5.1 From d528f0c87e70391c71a54213cf7ac85a77211e89 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Tue, 4 Nov 2025 14:02:11 -0800 Subject: [PATCH 017/130] ci: Update cudnn version requirements in CI container (#2039) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description cuDNN versions specified in CI container setup (`docker/install/install_python_packages.sh`) are currently 9.11 and 9.12. In unit testing, this causes issues as `mm_fp4(backend='cudnn')` is not supported on Spark (sm121) for older cuDNN versions in cu130. Failure is due to cuDNN version shipped with container being too old. In the [latest container build pipeline output](https://github.com/flashinfer-ai/flashinfer/actions/runs/18778064727/job/53577233568#step:6:727), cudnn 9.13.0.50 is installed ``` #16 207.0 Requirement already satisfied: nvidia-cudnn-cu13>=9.12.0.46 in /opt/conda/envs/py312/lib/python3.12/site-packages (9.13.0.50) #16 207.0 Requirement already satisfied: nvidia-cublas in /opt/conda/envs/py312/lib/python3.12/site-packages (from nvidia-cudnn-cu13>=9.12.0.46) (13.0.0.19) ``` Current PR updates the minimum cudnn version for both [cu12](https://pypi.org/project/nvidia-cudnn-cu12/#history) and [cu13](https://pypi.org/project/nvidia-cudnn-cu13/#history) to 9.14.0.64. cudnn 9.13 --> unit test fails with 180 failed, 270 passed, 2790 skipped, 1 warning in 8.97s ``` # pytest tests/gemm/test_mm_fp4.py =================================================================================================================================================== test session starts =================================================================================================================================================== platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /flashinfer configfile: pytest.ini collected 3240 items ... FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-256] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_ FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-512] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_ ================================================================================================================================ 180 failed, 270 passed, 2790 skipped, 1 warning in 8.97s ================================================================================================================================= ``` cudnn 9.14 --> unit test passes with 450 passed, 2790 skipped, 1 warning in 5.37s ``` # pytest tests/gemm/test_mm_fp4.py =================================================================================================================================================== test session starts =================================================================================================================================================== platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /flashinfer configfile: pytest.ini collected 3240 items tests/gemm/test_mm_fp4.py ... 
====================================================================================================================================== 450 passed, 2790 skipped, 1 warning in 5.37s ======================================================================================================================================= ``` ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Updated internal dependencies for improved system stability and compatibility. --- docker/install/install_python_packages.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install/install_python_packages.sh b/docker/install/install_python_packages.sh index 465b9bf572..d898dae495 100644 --- a/docker/install/install_python_packages.sh +++ b/docker/install/install_python_packages.sh @@ -30,8 +30,8 @@ pip3 install responses pytest scipy build cuda-python nvidia-nvshmem-cu12 # Install cudnn package based on CUDA version if [[ "$CUDA_VERSION" == *"cu13"* ]]; then pip3 install --upgrade cuda-python==13.0 - pip3 install "nvidia-cudnn-cu13>=9.12.0.46" + pip3 install "nvidia-cudnn-cu13>=9.14.0.64" else pip3 install --upgrade cuda-python==12.* - pip3 install "nvidia-cudnn-cu12>=9.11.0.98" + pip3 install "nvidia-cudnn-cu12>=9.14.0.64" fi From f2cc526755a5baaf74722cda017bb446cee839e8 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Tue, 4 Nov 2025 14:48:01 -0800 Subject: [PATCH 018/130] test: Mark test_fp8_prefill.py as xfail on SM90 (#2038) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description `test_fp8_prefill.py` is currently failing on SM90, but consumes too much time to run/fail, causing unit-tests to time out. --Current PR marks it as xfail so that unit tests can progress forward.-- Update: Root cause of failure is because mixed precision attention is not available on `fa3` backend, but the attention prefill wrapper automatically selects `backend='fa3'` on SM90. Fix is to explicitly specify the `backend='fa2'` so that fa2 is always used. Status after fix: ``` $ pytest tests/attention/test_fp8_prefill.py =================================================================================================================================================== test session starts =================================================================================================================================================== ... collected 768 items tests/attention/test_fp8_prefill.py ............................................................................................................................................................................................................................................................................... 
[ 35%] ................................................................................................................................................................................................................................................................................................................... [ 75%] .............................................................................................................................................................................................. [100%] ======================================================================================================================================= 768 passed, 1 warning in 131.42s (0:02:11) ======================================================================================================================================== ``` ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Adjusted FP8/FP16 attention test configuration to explicitly select a backend during prefill/decoding, stabilizing test behavior across environments. * **Public API** * Constructors now accept an explicit backend parameter to allow selecting the backend used for KV cache operations. 
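As a sketch, the test-side change boils down to constructing the wrappers with an explicit backend. The buffer size matches the test; `kv_layout` is shown as `"NHD"` purely for illustration since the test parametrizes it:

```python
import torch
import flashinfer

# Pin both wrappers to fa2 so SM90 does not auto-select fa3, which lacks the
# mixed-precision (fp16 query + fp8 KV cache) path exercised by this test.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
kv_layout = "NHD"  # illustrative; the test parametrizes the layout
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout, backend="fa2"
)
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout, backend="fa2"
)
```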
--- tests/attention/test_fp8_prefill.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/attention/test_fp8_prefill.py b/tests/attention/test_fp8_prefill.py index 414173f452..1b8ebc75cc 100644 --- a/tests/attention/test_fp8_prefill.py +++ b/tests/attention/test_fp8_prefill.py @@ -66,7 +66,7 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) wrapper_f16 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + workspace_buffer, kv_layout, backend="fa2" ) wrapper_f16.plan( qo_indptr, @@ -90,7 +90,7 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( kv_data_fp8 = torch.cat([k_fp8, v_fp8], dim=1) wrapper_f8 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + workspace_buffer, kv_layout, backend="fa2" ) wrapper_f8.plan( qo_indptr, @@ -156,7 +156,7 @@ def test_batch_decode_with_prefill_with_paged_kv_cache( workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + workspace_buffer, kv_layout, backend="fa2" ) wrapper.plan( qo_indptr, @@ -173,7 +173,7 @@ def test_batch_decode_with_prefill_with_paged_kv_cache( o_fp8 = wrapper.run(q, kv_data) decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + workspace_buffer, kv_layout, backend="fa2" ) decode_wrapper.plan( kv_indptr, From e1c1e2aa3c4f0bd79428ec5785155ac823c9fa89 Mon Sep 17 00:00:00 2001 From: FlashInfer Bot Date: Tue, 4 Nov 2025 17:52:22 -0800 Subject: [PATCH 019/130] Update Docker CI tags to 20251104-d528f0c (#2041) This PR updates the Docker CI image tags to the latest version: `20251104-d528f0c` Updated images: - flashinfer/flashinfer-ci-cu126:20251104-d528f0c - flashinfer/flashinfer-ci-cu128:20251104-d528f0c - flashinfer/flashinfer-ci-cu129:20251104-d528f0c - flashinfer/flashinfer-ci-cu130:20251104-d528f0c Auto-generated by [release-ci-docker workflow](https://github.com/flashinfer-ai/flashinfer/actions/runs/19084098717) ## Summary by CodeRabbit * **Chores** * Updated Docker image tags to latest versions for CUDA 12.6, 12.8, 12.9, and 13.0 distributions. 
Co-authored-by: yzh119 <11773619+yzh119@users.noreply.github.com> --- ci/docker-tags.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/docker-tags.yml b/ci/docker-tags.yml index ba3a947bc6..3619e1bf5b 100644 --- a/ci/docker-tags.yml +++ b/ci/docker-tags.yml @@ -1,4 +1,4 @@ -flashinfer/flashinfer-ci-cu126: 20251024-0e48aaf -flashinfer/flashinfer-ci-cu128: 20251024-0e48aaf -flashinfer/flashinfer-ci-cu129: 20251024-0e48aaf -flashinfer/flashinfer-ci-cu130: 20251024-0e48aaf +flashinfer/flashinfer-ci-cu126: 20251104-d528f0c +flashinfer/flashinfer-ci-cu128: 20251104-d528f0c +flashinfer/flashinfer-ci-cu129: 20251104-d528f0c +flashinfer/flashinfer-ci-cu130: 20251104-d528f0c From 9bc5bd55f77811b5fe3b063cf002de8d49882c49 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 4 Nov 2025 19:30:01 -0800 Subject: [PATCH 020/130] bugfix: fix failed unittest `test_green_ctx` and `test_jit_example` on spark (sm_121) (#1951) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description There are three failed unittests on spark (sm_121): * tests/utils/test_green_ctx.py * tests/utils/test_jit_example.py * tests/utils/test_sampling.py First one is because spark has small number of SMs (48) and we don't have a guard on green context splitting. Second one is an unknown issue (logits don't match with reference) and probably related to barriers on sm_121, xfail now and will fix later. The last one will be fixed by another PR from @bkryu , this PR fixes the first two issues. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Tests now pre-check GPU resources and auto-skip with informative messages including available and requested SM counts to avoid spurious failures. * Added a conditional xfail for GPUs with compute capability 12.1 to avoid false negatives on that hardware. * Tightened a sampling test by adding a relative tolerance for more robust numerical validation. * **Bug Fixes** * Improved runtime error handling to surface clearer guidance when GPU SM resources are insufficient. 
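To make the failure mode above concrete, here is a hedged sketch of probing the device's SM count and guarding a green-context split on parts with few SMs (e.g. 48 on sm_121). The helpers used (`get_cudevice`, `get_device_resource`, `split_device_green_ctx`) are the ones exercised in the diff below; the group and SM counts are arbitrary example values, not recommendations.

```python
# Hedged sketch of the guard pattern introduced in this patch's tests.
import torch
from flashinfer import green_ctx

dev = torch.device("cuda:0")
total_sms = green_ctx.get_device_resource(green_ctx.get_cudevice(dev)).sm.smCount

num_groups, min_count = 3, 16  # example request
try:
    streams, resources = green_ctx.split_device_green_ctx(dev, num_groups, min_count)
except RuntimeError as err:
    # After this patch the raised error explains that the device ran out of SMs
    # for the requested partitioning.
    print(f"split failed ({total_sms} SMs available): {err}")
```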
--------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- flashinfer/green_ctx.py | 78 +++++++++---- tests/utils/test_green_ctx.py | 192 ++++++++++++++++++++++---------- tests/utils/test_jit_example.py | 6 +- 3 files changed, 197 insertions(+), 79 deletions(-) diff --git a/flashinfer/green_ctx.py b/flashinfer/green_ctx.py index 0555cec212..09962fd467 100644 --- a/flashinfer/green_ctx.py +++ b/flashinfer/green_ctx.py @@ -170,12 +170,27 @@ def split_device_green_ctx( RuntimeError: when requested SM allocation exceeds device capacity: ``num_groups * rounded_min_count > total_device_sms`` """ - cu_dev = get_cudevice(dev) - resource = get_device_resource(cu_dev) - results, remaining = split_resource(resource, num_groups, min_count) - resources = results + [remaining] - streams = create_green_ctx_streams(cu_dev, resources) - return streams, resources + try: + cu_dev = get_cudevice(dev) + resource = get_device_resource(cu_dev) + results, remaining = split_resource(resource, num_groups, min_count) + resources = results + [remaining] + streams = create_green_ctx_streams(cu_dev, resources) + return streams, resources + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + raise RuntimeError( + f"{e}\n" + f"Failed to split device into {num_groups} groups with min_count={min_count}. " + f"This is likely due to insufficient number of SMs available on the device. " + f"Please reduce the number of groups or the minimum SM count per group." + ) from e + raise def split_device_green_ctx_by_sm_count( @@ -241,21 +256,40 @@ def split_device_green_ctx_by_sm_count( See `CUDA Green Contexts `_ for more details. 
""" - cu_dev = get_cudevice(dev) - resource = get_device_resource(cu_dev) + try: + cu_dev = get_cudevice(dev) + resource = get_device_resource(cu_dev) + + # Round sm counts to meet the alignment and granularity requirements + rounded_sm_counts = [] + for sm_count in sm_counts: + min_sm_count, sm_alignment = get_sm_count_constraint( + *get_compute_capability(dev) + ) + if sm_count <= 0: + raise ValueError(f"SM count must be positive, got {sm_count}") + rounded_sm_counts.append( + round_up(max(sm_count, min_sm_count), sm_alignment) + ) - # Round sm counts to meet the alignment and granularity requirements - rounded_sm_counts = [] - for sm_count in sm_counts: - min_sm_count, sm_alignment = get_sm_count_constraint( - *get_compute_capability(dev) + # Split the device into multiple green contexts + results, remaining = split_resource_by_sm_count( + cu_dev, resource, rounded_sm_counts ) - if sm_count <= 0: - raise ValueError(f"SM count must be positive, got {sm_count}") - rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment)) - - # Split the device into multiple green contexts - results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts) - resources = results + [remaining] - streams = create_green_ctx_streams(cu_dev, resources) - return streams, resources + resources = results + [remaining] + streams = create_green_ctx_streams(cu_dev, resources) + return streams, resources + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + raise RuntimeError( + f"{e}\n" + f"Failed to split device with SM counts {sm_counts} (rounded to {rounded_sm_counts}). " + f"This is likely due to insufficient number of SMs available on the device. " + f"Please reduce the requested SM counts or use fewer partitions." + ) from e + raise diff --git a/tests/utils/test_green_ctx.py b/tests/utils/test_green_ctx.py index 4863dd5c51..99d6dc97bc 100644 --- a/tests/utils/test_green_ctx.py +++ b/tests/utils/test_green_ctx.py @@ -12,14 +12,30 @@ def test_green_ctx_creation( num_groups: int, min_count: int, ): - streams, resources = green_ctx.split_device_green_ctx( - torch.device(device), num_groups, min_count - ) + try: + streams, resources = green_ctx.split_device_green_ctx( + torch.device(device), num_groups, min_count + ) - assert len(resources) == num_groups + 1 - for resource in resources[:-1]: - sm_count = resource.sm.smCount - assert sm_count >= min_count + assert len(resources) == num_groups + 1 + for resource in resources[:-1]: + sm_count = resource.sm.smCount + assert sm_count >= min_count + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. 
Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}" + ) + raise @pytest.mark.parametrize("device", ["cuda:0"]) @@ -30,19 +46,35 @@ def test_green_ctx_kernel_execution( num_groups: int, min_count: int, ): - streams, resources = green_ctx.split_device_green_ctx( - torch.device(device), num_groups, min_count - ) - num_partitions = num_groups + 1 - assert len(streams) == num_partitions - assert len(resources) == num_partitions - - for stream in streams: - with torch.cuda.stream(stream): - x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) - y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) - z = x @ y - print(z.shape) + try: + streams, resources = green_ctx.split_device_green_ctx( + torch.device(device), num_groups, min_count + ) + num_partitions = num_groups + 1 + assert len(streams) == num_partitions + assert len(resources) == num_partitions + + for stream in streams: + with torch.cuda.stream(stream): + x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) + y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) + z = x @ y + print(z.shape) + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}" + ) + raise @pytest.mark.parametrize("device", ["cuda:0"]) @@ -59,17 +91,33 @@ def test_split_device_green_ctx_by_sm_count_creation( device: str, sm_counts: list, ): - streams, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) - num_partitions = len(sm_counts) + 1 - assert len(resources) == num_partitions - assert len(streams) == num_partitions - - # Check that each partition has the expected SM count - for i, expected_sm_count in enumerate(sm_counts): - actual_sm_count = resources[i].sm.smCount - assert actual_sm_count >= expected_sm_count + try: + streams, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts + ) + num_partitions = len(sm_counts) + 1 + assert len(resources) == num_partitions + assert len(streams) == num_partitions + + # Check that each partition has the expected SM count + for i, expected_sm_count in enumerate(sm_counts): + actual_sm_count = resources[i].sm.smCount + assert actual_sm_count >= expected_sm_count + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. 
Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) + raise @pytest.mark.parametrize("device", ["cuda:0"]) @@ -85,19 +133,35 @@ def test_split_device_green_ctx_by_sm_count_kernel_execution( device: str, sm_counts: list, ): - streams, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) - num_partitions = len(sm_counts) + 1 - assert len(streams) == num_partitions - assert len(resources) == num_partitions - - for i, stream in enumerate(streams): - with torch.cuda.stream(stream): - x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) - y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) - z = x @ y - print(f"Partition {i}: {z.shape}") + try: + streams, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts + ) + num_partitions = len(sm_counts) + 1 + assert len(streams) == num_partitions + assert len(resources) == num_partitions + + for i, stream in enumerate(streams): + with torch.cuda.stream(stream): + x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + z = x @ y + print(f"Partition {i}: {z.shape}") + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) + raise @pytest.mark.parametrize("device", ["cuda:0"]) @@ -113,16 +177,32 @@ def test_split_device_green_ctx_by_sm_count_alignment( device: str, sm_counts: list, ): - _, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts - ) - - for resource in resources[:-1]: # Exclude remaining SMs - sm_count = resource.sm.smCount - assert sm_count > 0 - - min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint( - *green_ctx.get_compute_capability(torch.device(device)) + try: + _, resources = green_ctx.split_device_green_ctx_by_sm_count( + torch.device(device), sm_counts ) - assert sm_count >= min_sm_count - assert sm_count % sm_alignment == 0 + + for resource in resources[:-1]: # Exclude remaining SMs + sm_count = resource.sm.smCount + assert sm_count > 0 + + min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint( + *green_ctx.get_compute_capability(torch.device(device)) + ) + assert sm_count >= min_sm_count + assert sm_count % sm_alignment == 0 + except RuntimeError as e: + if ( + "CUDA error code=914" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e) + or "CUDA error code=915" in str(e) + or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e) + ): + # Get total SM count on the device + cu_dev = green_ctx.get_cudevice(torch.device(device)) + device_resource = green_ctx.get_device_resource(cu_dev) + total_sms = device_resource.sm.smCount + pytest.skip( + f"Insufficient SMs on device. 
Total SMs available: {total_sms}, requested SM counts: {sm_counts}" + ) + raise diff --git a/tests/utils/test_jit_example.py b/tests/utils/test_jit_example.py index fb169f1a7f..959f303914 100644 --- a/tests/utils/test_jit_example.py +++ b/tests/utils/test_jit_example.py @@ -11,7 +11,7 @@ gen_customize_single_prefill_module, ) from flashinfer.prefill import single_prefill_with_kv_cache_with_jit_module -from flashinfer.utils import MaskMode, is_sm90a_supported +from flashinfer.utils import MaskMode, is_sm90a_supported, get_compute_capability def test_single_decode_mask(): @@ -166,6 +166,10 @@ def test_flash_sigmoid(): torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2) +@pytest.mark.xfail( + get_compute_capability(torch.device("cuda:0")) == (12, 1), + reason="Numerical accuracy issue on SM 121 (Spark)", +) def test_dump_logits(): torch.manual_seed(42) variant_decl = r""" From 2580610b46ba5fa45532256ca77730acba2f5a03 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Tue, 4 Nov 2025 22:07:19 -0800 Subject: [PATCH 021/130] perf: Speed up fp4 quantization for small batch with swizzling for cutlass MoE (#2025) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Performance optimization for `fp4_quantize()` function. The performance issue was raised in issues #1734 and #2021 Observed behavior was slow performance when `is_sf_swizzled_layout=True` (as opposed to False). Root cause of the issue was * Excessive Padding Overhead: Swizzled layouts require row padding to tile boundaries where `SWIZZLED_128x4` pads to multiples of 128 rows and `SWIZZLED_8x4` pads to multiples of 8 rows * This means `For batch_size=1` with SWIZZLED_128x4: 127 out of 128 rows are padding (99.2% wasted work) * Sequential Processing: The original grid launch used grid.x = min(m, multiProcessorCount * numBlocksPerSM), so: For batch_size=1: only 1 block launched * This single block iterated sequentially over all 128 padded rows * Each padding row still computed scale factors, checked bounds, and performed conditional logic * No Fast Path: Every row (real or padding) went through the same expensive code path with multiple conditional branches The fix: 1. Kernel-Level Early Exit Fast Path (`quantization.cuh`): Added branch divergence optimization with separate handling for padding vs. data rows - Padding rows now execute ~10ร— fewer instructions; Eliminates memory loads/stores for input/output data on padding rows; Reduces register pressure and divergence overhead 2. 
Host-Level Parallel Grid Launch (`quantization.cu`): Modified grid calculation to launch blocks proportional to padded rows instead of actual rows: - For batch_size=1 with SWIZZLED_128x4: launches up to 128 blocks instead of 1; Each block processes 1 row in parallel instead of sequentially; overall tries to achieve full GPU occupancy even with small batch sizes `fp4_quantize()` performance before fix: ``` $ python3 bench_fp4_quantize.py +------------+---------------------+-------------------------+ | batch size | swizzled_times (us) | non_swizzled_times (us) | +------------+---------------------+-------------------------+ | 1.0 | 71.52 | 3.136 | | 2.0 | 37.152 | 3.168 | | 4.0 | 19.904 | 3.168 | | 8.0 | 11.296 | 3.2 | | 16.0 | 7.103 | 3.296 | | 32.0 | 4.96 | 3.376 | | 64.0 | 4.128 | 3.487 | | 128.0 | 3.808 | 3.648 | | 256.0 | 4.32 | 4.161 | | 512.0 | 5.472 | 5.184 | +------------+---------------------+-------------------------+ ``` After fix in current PR: ``` $ python3 bench_fp4_quantize.py +------------+---------------------+-------------------------+ | batch size | swizzled_times (us) | non_swizzled_times (us) | +------------+---------------------+-------------------------+ | 1.0 | 3.456 | 3.264 | | 2.0 | 3.488 | 3.296 | | 4.0 | 3.536 | 3.296 | | 8.0 | 3.52 | 3.296 | | 16.0 | 3.52 | 3.456 | | 32.0 | 3.696 | 3.488 | | 64.0 | 3.744 | 3.584 | | 128.0 | 3.936 | 3.776 | | 256.0 | 4.384 | 4.288 | | 512.0 | 5.568 | 5.248 | +------------+---------------------+-------------------------+ ``` where the `bench_fp4_quantize.py` script used to benchmark (adopted from #1734) : ``` from flashinfer.testing.utils import bench_gpu_time_with_cupti from flashinfer import fp4_quantize import torch import numpy as np import pandas as pd from tabulate import tabulate A_scale = torch.randn(16).cuda().float() bsz = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] swizzled_times = [] for bs in bsz: A = torch.randn(bs, 5120).cuda().to(torch.bfloat16) t = np.median(bench_gpu_time_with_cupti( lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=True), dry_run_iters = 10, repeat_iters = 100, ) ) * 1000 swizzled_times.append(t) non_swizzled_times = [] for bs in bsz: A = torch.randn(bs, 5120).cuda().to(torch.bfloat16) t = np.median(bench_gpu_time_with_cupti( lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=False), dry_run_iters = 10, repeat_iters = 100, ) ) * 1000 non_swizzled_times.append(t) summary_df = pd.DataFrame({ "batch size": bsz, "swizzled_times (us)": swizzled_times, "non_swizzled_times (us)": non_swizzled_times, }) # Round numeric columns to three decimals before printing summary_df_rounded = summary_df.copy() summary_df_rounded["batch size"] = summary_df_rounded["batch size"].astype(int) summary_df_rounded["swizzled_times (us)"] = summary_df_rounded["swizzled_times (us)"].round(3) summary_df_rounded["non_swizzled_times (us)"] = summary_df_rounded["non_swizzled_times (us)"].round(3) print(tabulate(summary_df_rounded, headers='keys', tablefmt='pretty', showindex=False)) ``` ## ๐Ÿ” Related Issues #1734 #2021 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. 
> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Improved quantization for swizzled memory layouts by adjusting how effective processing rows are computed to better utilize GPU resources. * Added early-exit handling for padding-only rows so padding outputs are zeroed without processing data. * Ensured consistent zeroing of scale/format outputs for padded columns across all quantization paths. --- csrc/nv_internal/cpp/kernels/quantization.cu | 24 +++- .../tensorrt_llm/kernels/quantization.cuh | 116 +++++++++++------- 2 files changed, 94 insertions(+), 46 deletions(-) diff --git a/csrc/nv_internal/cpp/kernels/quantization.cu b/csrc/nv_internal/cpp/kernels/quantization.cu index 458cafd2f6..9021bd0847 100644 --- a/csrc/nv_internal/cpp/kernels/quantization.cu +++ b/csrc/nv_internal/cpp/kernels/quantization.cu @@ -70,6 +70,21 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper function for grid configuration with swizzled layouts + +inline int computeEffectiveRows(int m, QuantizationSFLayout layout) { + int effectiveRows = m; + bool isSfSwizzledLayout = (layout == QuantizationSFLayout::SWIZZLED_128x4 || + layout == QuantizationSFLayout::SWIZZLED_8x4); + if (isSfSwizzledLayout) { + int rowTile = (layout == QuantizationSFLayout::SWIZZLED_128x4) ? 128 : 8; + int numPaddedRows = (m + rowTile - 1) / rowTile * rowTile; // Round up to rowTile + effectiveRows = numPaddedRows; + } + return effectiveRows; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// // MXFP8 Quantization @@ -85,7 +100,8 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, dim3 block(std::min(int(padded_n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + int effectiveRows = computeEffectiveRows(m, layout); + dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. cudaLaunchConfig_t config; @@ -177,7 +193,8 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS dim3 block(std::min(int(n / CVT_FP8_TO_FP4_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + int effectiveRows = computeEffectiveRows(m, layout); + dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. auto* kernel_instance = useUE8M0 @@ -197,7 +214,8 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS dim3 block(std::min(int(n / CVT_ELTS_PER_THREAD), 512)); // Get number of blocks per SM (assume we can fully utilize the SM). 
int const numBlocksPerSM = std::max(1u, 2048u / block.x); - dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + int effectiveRows = computeEffectiveRows(m, layout); + dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. auto* kernel_instance = useUE8M0 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh index 237b59eeaf..7abf2eb631 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh +++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh @@ -778,56 +778,86 @@ quantize_with_block_size( int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD; asm volatile("griddepcontrol.wait;"); + // Input tensor batch/row/col loops. + // Optimization: Iterate over actual rows first (hot path), then padding rows (cold path) + // This improves performance for small batch sizes with swizzled layout for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) { - for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { - for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { - std::optional optionalBatchIdx = batchIdx; - std::optional optionalNumRows = numRows; - - // The SF output pointer. - auto sf_out = cvt_quant_get_sf_out_offset( - optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, - layout); - - // The input tensor offset. - int64_t inOffset = - static_cast(batchIdx * numRows + rowIdx) * numColThreads + colIdx; - int64_t outOffset = - static_cast(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx; - - // Set the values to 0 of those are padded columns. - if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) { - // Dispatch the quantization kernel. - if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { - reinterpret_cast(out)[outOffset] = 0u; - } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 || - quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { - reinterpret_cast(out)[outOffset] = 0ull; - } - } + // Early exit for padding-only blocks: if this block only processes padding rows, + // we can skip the batch loop and just zero out the scale factors + bool isRowPadding = (rowIdx >= numRows); + + if (isRowPadding) { + // Fast path: This row is entirely padding, only zero out scale factors. + // Note: Padding rows do NOT exist in the output tensor (which is sized [numRows, K]), + // they only exist in the swizzled scale factor layout. Do NOT write to output buffer here. + for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { + for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { + std::optional optionalBatchIdx = batchIdx; + std::optional optionalNumRows = numRows; + + // The SF output pointer. + auto sf_out = cvt_quant_get_sf_out_offset( + optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numColsForSf / SF_VEC_SIZE, SFout, + layout); - // Set the SF padding to 0. - if (rowIdx >= numRows || colIdx >= numColThreads) { // Set the SF padding to 0. if (sf_out != nullptr) { sf_out[0] = 0x00; } - } else { - // Load the input vector. - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - - // Dispatch the quantization kernel. 
- if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { - reinterpret_cast(out)[outOffset] = - cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); - } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) { - reinterpret_cast(out)[outOffset] = - cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, - sf_out); - } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { - reinterpret_cast(out)[outOffset] = - cvt_warp_fp16_to_mxfp8(in_vec, sf_out); + } + } + } else { + // Normal path: This row contains actual data + for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) { + for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) { + std::optional optionalBatchIdx = batchIdx; + std::optional optionalNumRows = numRows; + + // The SF output pointer. + auto sf_out = cvt_quant_get_sf_out_offset( + optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numColsForSf / SF_VEC_SIZE, SFout, + layout); + + // The input tensor offset. + int64_t inOffset = + static_cast(batchIdx * numRows + rowIdx) * numColThreads + colIdx; + int64_t outOffset = + static_cast(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx; + + // Set the values to 0 of those are padded columns. + if (colIdx >= numColThreads && colIdx < numPaddedColThreads) { + // Dispatch the quantization kernel. + if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { + reinterpret_cast(out)[outOffset] = 0u; + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 || + quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { + reinterpret_cast(out)[outOffset] = 0ull; + } + } + + // Process actual data or padding + if (colIdx >= numColThreads) { + // Column padding: Set the SF padding to 0. + if (sf_out != nullptr) { + sf_out[0] = 0x00; + } + } else { + // Load the input vector. + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + + // Dispatch the quantization kernel. + if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) { + reinterpret_cast(out)[outOffset] = + cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) { + reinterpret_cast(out)[outOffset] = + cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, + sf_out); + } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) { + reinterpret_cast(out)[outOffset] = + cvt_warp_fp16_to_mxfp8(in_vec, sf_out); + } } } } From 579012b9f0b696d0f6cd5cd526d778e7119740d1 Mon Sep 17 00:00:00 2001 From: Jimmy Zhou <79552142+jimmyzho@users.noreply.github.com> Date: Wed, 5 Nov 2025 01:08:19 -0500 Subject: [PATCH 022/130] Support cc common check decorator for empty backends (#2015) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). 
## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Improved backend/compute-capability validation with clearer errors and correct fallback when backend-specific checks are absent. * **New Features** * Decorated functions expose runtime attributes to query backend availability and choices. * Default-backend behavior: kernels use a default when none is passed. * **Compatibility** * Expanded supported compute-capability set and raised minimum cuDNN package requirements. * **Tests** * Added tests for empty-backend common-checks and default-backend behavior. * **Chores** * Version bumped to 0.5.1. --- flashinfer/utils.py | 104 +++++++++++++++++-------- tests/utils/test_decorators.py | 136 +++++++++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 31 deletions(-) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index eb42e1291e..3aae147896 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -877,6 +877,7 @@ def backend_requirement( An optional function that performs additional validation checks common to all backends. Should accept the same arguments as the decorated function and return True if requirements are met, False otherwise. + In the case where the kernel function does not have any specific backends, this can be decorated with @supported_compute_capability to specify the function's supported compute capabilities. Returns ------- @@ -927,17 +928,17 @@ def backend_requirement( ... # Backend invocation ... pass ... - >>> # Check if backend is supported - >>> my_attention_kernel.is_backend_supported("cutlass") - True - >>> # Check if backend supports specific compute capability - >>> my_attention_kernel.is_backend_supported("cutlass", 75) - False - >>> my_attention_kernel.is_backend_supported("cutlass", 80) - True - >>> # Check if any backend supports a compute capability - >>> my_attention_kernel.is_compute_capability_supported(75) - True + >>> # Example with kernel function with no backend requirements + >>> @supported_compute_capability([80, 86, 89, 90]) + ... def _common_size_check(q, k, v): + ... return True + ... + >>> @backend_requirement( + ... backend_checks={}, # Empty backend_checks + ... common_check=_common_size_check + ... ) + ... def backend_agnostic_kernel(q, k, v): + ... pass Notes ----- @@ -955,30 +956,50 @@ def decorator(func): sig = inspect.signature(func) def is_backend_supported(backend, cc=None): - # Is this backend present? - if backend not in backend_checks: + # No backend-specific checks + if not has_backend_choices(): + raise ValueError( + f"Invalid is_backend_supported call: no backend choices for {func.__name__}" + ) + else: + # Is this backend present? 
+ if backend not in backend_checks: + return False + req_checker = backend_checks[backend] + # If user just wants to check if the backend is supported (regardless of compute capability), return True + if cc is None: + return True + # Check compute capability support via attribute on requirement function + elif hasattr(req_checker, "is_compute_capability_supported"): + return req_checker.is_compute_capability_supported(cc) return False - req_checker = backend_checks[backend] - # If user just wants to check if the backend is supported (regardless of compute capability), return True - if cc is None: - return True - # Check compute capability support via attribute on requirement function - elif hasattr(req_checker, "is_compute_capability_supported"): - return req_checker.is_compute_capability_supported(cc) - return False def is_compute_capability_supported(cc): - # True if any backend requirement supports this cc - return any( - hasattr(checker, "is_compute_capability_supported") - and checker.is_compute_capability_supported(cc) - for checker in backend_checks.values() - ) + # In case there is only 1 implicit backend, the compute capability support needs to be added to the common check + if not has_backend_choices(): + # No backend-specific checks, only check common_check + if not hasattr(common_check, "is_compute_capability_supported"): + raise ValueError( + f"Invalid is_compute_capability_supported call: {common_check.__name__} does not have is_compute_capability_supported decorator" + ) + return common_check.is_compute_capability_supported(cc) + else: + # True if any backend requirement supports this cc + return any( + hasattr(checker, "is_compute_capability_supported") + and checker.is_compute_capability_supported(cc) + for checker in backend_checks.values() + ) # @note: this function does not automatically apply defaults to the arguments. def _is_problem_size_supported(*args, **kwargs): # At this point, kwargs should have defaults applied, so backend should be present backend = kwargs.get("backend") + + # Handle empty backend_checks case + if not has_backend_choices(): + return common_check(*args, **kwargs) + if backend not in backend_checks: raise BackendSupportedError( f"Backend '{backend}' is not supported for {func.__name__}" @@ -989,6 +1010,14 @@ def _is_problem_size_supported(*args, **kwargs): else: return req_checker(*args, **kwargs) + def has_backend_choices() -> bool: + # Whether there are any backend choices to make + return bool(backend_checks) + + def has_backend(backend: str) -> bool: + # Whether the given backend exists in the API + return backend in backend_checks + # @brief: Wrapper function that calls the orignal, decorated function, after applying a number of checks. # @note that here we manually apply defaults to the arguments in the wrapper function when doing validation. 
@functools.wraps(func) @@ -1024,11 +1053,22 @@ def wrapper(*args, **kwargs): major, minor = get_compute_capability(tensor_arg.device) capability = major * 10 + minor - if not is_backend_supported(backend, capability): - extra = f" with capability {capability}" if capability else "" - raise BackendSupportedError( - f"{func.__name__} does not support backend '{backend}'{extra}" + if not has_backend_choices() and common_check is None: + raise ValueError( + f"Invalid @backend_requirement decorator usage: no backend choices and no common_check for {func.__name__}" ) + + if has_backend_choices(): + if not is_backend_supported(backend, capability): + extra = f" with capability {capability}" if capability else "" + raise BackendSupportedError( + f"{func.__name__} does not support backend '{backend}'{extra}" + ) + else: + if not is_compute_capability_supported(capability): + raise BackendSupportedError( + f"{func.__name__} does not support compute capability {capability}" + ) if not _is_problem_size_supported(**kwargs_with_defaults): raise ValueError( f"Problem size is not supported for {func.__name__}" @@ -1038,6 +1078,8 @@ def wrapper(*args, **kwargs): wrapper.is_backend_supported = is_backend_supported wrapper.is_compute_capability_supported = is_compute_capability_supported + wrapper.has_backend = has_backend + wrapper.has_backend_choices = has_backend_choices return wrapper return decorator diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index e0528cfd60..4f052019df 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -115,6 +115,142 @@ def my_kernel(x, backend="cudnn"): assert my_kernel.is_compute_capability_supported(70) is False # neither has it +def test_backend_requirement_empty_backends_with_common_check_cc(): + """Test backend_requirement with empty backend_checks but common_check with compute capability.""" + + # Made up compute capability + @supported_compute_capability([42]) + def _common_check(x): + # Common check with compute capability restrictions + return x.shape[0] <= 1024 + + @backend_requirement( + {}, # Empty backend_checks + common_check=_common_check, + ) + def unsupported_kernel(x): + return x * 2 + + # Check methods + assert hasattr(unsupported_kernel, "is_backend_supported") + assert hasattr(unsupported_kernel, "is_compute_capability_supported") + + # Check compute capability support (only common_check) + assert unsupported_kernel.is_compute_capability_supported(42) is True + assert unsupported_kernel.is_compute_capability_supported(75) is False + + # The following tests are for when no backend choices are provided, where + # `is_backend_supported` is undefined behaviour and will raise error. + # We also enforce the `common_check` function when using `@backend_requirement` decorator. + # It must also be decorated with `@supported_compute_capability`. + + # Raise error: is_backend_supported cannot be called with no backend choices. 
+ for backend in [ + ("random_backend", 42), + ("random_backend", 75), + (None, 42), + (None, 75), + ]: + with pytest.raises( + ValueError, + match="Invalid is_backend_supported call: no backend choices for unsupported_kernel", + ): + unsupported_kernel.is_backend_supported(backend[0], backend[1]) + + # Test compute capability support during kernel runtime + x = torch.randn(10, 10, device="cuda") + + # Error: no real compute capability is supported + with pytest.raises( + BackendSupportedError, match="does not support compute capability" + ): + unsupported_kernel(x) + + actual_capability = torch.cuda.get_device_capability(x.device) + major, minor = actual_capability + actual_capability = major * 10 + minor + + @supported_compute_capability([actual_capability]) + def _common_check(x): + return True + + @backend_requirement( + {}, + common_check=_common_check, + ) + def supported_kernel(x): + return x * 2 + + assert supported_kernel.is_compute_capability_supported(actual_capability) is True + + # Raise error: is_backend_supported cannot be called with no backend choices. + with pytest.raises( + ValueError, + match="Invalid is_backend_supported call: no backend choices for supported_kernel", + ): + supported_kernel.is_backend_supported(None, actual_capability) + assert supported_kernel.has_backend("random_backend") is False + + result = supported_kernel(x) + assert result.shape == x.shape + + # Enforce the `common_check` function to have `is_compute_capability_supported` decorator. + def _bad_common_check(x): + return True + + @backend_requirement( + {}, + common_check=_bad_common_check, + ) + def bad_kernel(x): + return x * 2 + + with pytest.raises( + ValueError, + match="Invalid is_compute_capability_supported call: _bad_common_check does not have is_compute_capability_supported decorator", + ): + bad_kernel.is_compute_capability_supported(42) + + # Enforce `common_check` function in @backend_requirement decorator. 
+ @backend_requirement({}) + def kernel_no_common_check(x): + return x * 2 + + with pytest.raises( + ValueError, + match="Invalid @backend_requirement decorator usage: no backend choices and no common_check for kernel_no_common_check", + ): + x = torch.randn(10, 10, device="cuda") + kernel_no_common_check(x) + + +def test_has_backend(): + """Test the has_backend method.""" + + @backend_requirement({"cudnn": lambda x: True, "cutlass": lambda x: True}) + def my_kernel(x, backend="cudnn"): + return x * 2 + + assert my_kernel.has_backend("cudnn") is True + assert my_kernel.has_backend("cutlass") is True + assert my_kernel.has_backend("random_backend") is False + + +def test_has_backend_choices(): + """Test the has_backend_choices method.""" + + @backend_requirement({"cudnn": lambda x: True, "cutlass": lambda x: True}) + def my_kernel(x, backend="cudnn"): + return x * 2 + + @backend_requirement({}) + def my_kernel_no_backend(x): + return x * 2 + + assert my_kernel.has_backend_choices() is True + assert my_kernel_no_backend.has_backend_choices() is False + + def test_backend_requirement_wrapped_function(): """Test the backend_requirement decorator's wrapped function.""" if not torch.cuda.is_available(): From 6d19a75a6963b0996b7b2cfe5f9e385c5cc41a66 Mon Sep 17 00:00:00 2001 From: qsang-nv <200703406+qsang-nv@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:26:01 +0800 Subject: [PATCH 023/130] use scalar for kv_scale in xqa (#2033) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Breaking Changes** * Public xqa/xqa_mla entry points now accept kv_scale as a plain float (default 1.0) instead of a 1-element tensor. Update call sites accordingly. * **Documentation** * Docstrings updated to reflect kv_scale as float. * **Tests** * Tests updated to pass scalar kv_scale, with added parameterization and conditional skip for FP8 kv-cache scenarios. 
--------- Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> --- csrc/flashinfer_xqa_binding.cu | 7 +++---- csrc/xqa/mha.cu | 15 ++++++--------- csrc/xqa/mha.h | 21 ++++++++------------- csrc/xqa/mha_sm90.cu | 17 +++++++---------- csrc/xqa/mla_sm120.cu | 16 ++++++---------- csrc/xqa/xqa_wrapper.cu | 13 +++++-------- flashinfer/decode.py | 4 +--- flashinfer/xqa.py | 30 ++++++++++-------------------- tests/attention/test_xqa.py | 19 ++++++++++++++----- 9 files changed, 60 insertions(+), 82 deletions(-) diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu index 40b4168a9b..8556fb5e48 100644 --- a/csrc/flashinfer_xqa_binding.cu +++ b/csrc/flashinfer_xqa_binding.cu @@ -19,9 +19,8 @@ #if MLA_WRAPPER void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, - int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, - TensorView kvCacheScale, TensorView semaphores, TensorView scratch, - bool enable_pdl); + int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, + TensorView semaphores, TensorView scratch, bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla); @@ -34,7 +33,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK #endif TensorView q, tvm::ffi::Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, - TensorView seqLen, int64_t batchSize, TensorView kvCacheScale, + TensorView seqLen, int64_t batchSize, double kvCacheScale, #if SPEC_DEC int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index 1693d025f0..715267bedc 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -1301,8 +1301,7 @@ CUBIN_EXPORT __global__ #endif #endif uint32_t const batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V - // cache. Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head, uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { assert(allowMultiBlockMode || gridDim.x == 1); @@ -1503,7 +1502,7 @@ CUBIN_EXPORT __global__ }; if (warpIdx.z == 0) { float const qkScale = - qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * + qScale * (isKVCacheQuantized ? kvCacheScale : 1.f) * rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. CircIdx idxCurrSMemKBuf{nbKBuffers - 1}; auto const getSMemKTile = [&](uint32_t idx) -> SharedMem::KSmemBuffer& { @@ -2156,7 +2155,7 @@ CUBIN_EXPORT __global__ } } - float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F); + float voScale = (isKVCacheQuantized ? kvCacheScale : 1.F); if (seqIterInit < nbSeqIters) { // otherwise rcpRowSum will be NAN. // The attention sinks are moved to the multi-block reduction part if the multi-block is // enabled. @@ -2410,8 +2409,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( BeamSearchParams const beamSearchParams, #endif uint32_t const batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. - // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. 
uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head, uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { #if SPEC_DEC @@ -2469,8 +2467,7 @@ void launchMHA( BeamSearchParams const& beamSearchParams, #endif uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. - // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, #endif @@ -2571,7 +2568,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, + float kvCacheScale, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h index 43aed55f95..ee4584ee84 100644 --- a/csrc/xqa/mha.h +++ b/csrc/xqa/mha.h @@ -115,8 +115,7 @@ void launchMHA( BeamSearchParams const& beamSearchParams, #endif uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. - // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, #endif @@ -131,7 +130,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, + float kvCacheScale, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif @@ -166,8 +165,7 @@ void launchHopperF8MHA( BeamSearchParams const& beamSearchParams, #endif uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. - // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, #endif @@ -181,8 +179,7 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, - uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, + uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif @@ -197,11 +194,10 @@ void launchMLA( GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, // device pointer. shape: - // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or - // [batchSize][maxNbPagesPerSeq] (Layout 1) + // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] + // (Layout 0) or [batchSize][maxNbPagesPerSeq] (Layout 1) uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. - // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. 
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream); void launchMLAFlashInfer( @@ -214,8 +210,7 @@ void launchMLAFlashInfer( // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or // [batchSize][maxNbPagesPerSeq] (Layout 1) uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. - // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream); diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu index 495f4d2d46..d0de67c372 100644 --- a/csrc/xqa/mha_sm90.cu +++ b/csrc/xqa/mha_sm90.cu @@ -626,8 +626,7 @@ __launch_bounds__(128 * 3) BeamSearchParams const beamSearchParams, #endif uint32_t const batchSize, - float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and - // V cache. Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. __grid_constant__ CUtensorMap const tensorMapVLLMK, __grid_constant__ CUtensorMap const tensorMapVLLMV, #if SPEC_DEC @@ -773,7 +772,7 @@ __launch_bounds__(128 * 3) } float const qkScale = - qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * + qScale * (isKVCacheQuantized ? kvCacheScale : 1.f) * rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. uint32_t const warpRank = warpIdx.x; @@ -962,7 +961,7 @@ __launch_bounds__(128 * 3) #else constexpr float oScale = 1.F; #endif - float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale; + float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale : 1.f) * oScale; Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction. gmma::fence(); @@ -1316,7 +1315,7 @@ __launch_bounds__(128 * 3) headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; IOHead const& inKHead = qkv[inputKHeadOffset]; uint32_t const lane = laneId(); - float const rcpKScale = 1.F / kvCacheScale[0]; + float const rcpKScale = 1.F / kvCacheScale; #if ROPE_STYLE == 0 constexpr bool isNeox = false; auto const pairs = @@ -1375,7 +1374,7 @@ __launch_bounds__(128 * 3) (headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; IOHead const& inVHead = qkv[inputVHeadOffset]; uint32_t const lane = laneId(); - float const rcpVScale = 1.F / kvCacheScale[0]; + float const rcpVScale = 1.F / kvCacheScale; constexpr bool isNeox = false; auto const pairs = loadHead(inVHead, lane) * rcpVScale; @@ -2931,8 +2930,7 @@ void launchHopperF8MHA( BeamSearchParams const& beamSearchParams, #endif uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. - // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. 
#if SPEC_DEC SpecDecParams const& specDecParams, #endif @@ -3044,8 +3042,7 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, - uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, + uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif diff --git a/csrc/xqa/mla_sm120.cu b/csrc/xqa/mla_sm120.cu index 2396fb8c5b..ffcf8ab3c5 100644 --- a/csrc/xqa/mla_sm120.cu +++ b/csrc/xqa/mla_sm120.cu @@ -395,8 +395,7 @@ struct KernelArgs { OutputHead* __restrict__ const& output; // [totalNbIntputTokens][nbQHeads] KVCacheList const& cacheList; uint32_t const& batchSize; - float const* __restrict__ const& kvCacheScale; // Device memory scalar. Same scale for K and V - // cache. Used only for int8/fp8 KV cache. + float kvCacheScale; // Same scale for K and V cache. Used only for int8/fp8 KV cache. Vec* __restrict__ const& cgaXBuf; // [totalNbInputTokens][maxNbSubSeq] uint32_t* __restrict__ const& semaphores; // [totalNbInputTokens] @@ -449,7 +448,7 @@ struct Producer { __syncthreads(); #endif if (threadIdx.x == 0) { - smem.qkScaleLog2e = args.qScale * args.kvCacheScale[0] * log2e; + smem.qkScaleLog2e = args.qScale * args.kvCacheScale * log2e; } if (threadIdx.x < headGrpSize) { @@ -1228,7 +1227,7 @@ __device__ inline void Consumer::compute() { ThrdRegRowMax const accRowSum = loadShmRowMax(smem.accRowSum[tileIdx.x], tileBase.y, lane); - float const xvScale = computeRowSumFromF8 ? args.kvCacheScale[0] : args.kvCacheScale[0] * xScale; + float const xvScale = computeRowSumFromF8 ? args.kvCacheScale : args.kvCacheScale * xScale; WarpOutputTile const output = finalize(acc, accRowSum, xvScale, lane); bool const isMultiBlockMode = (nbSubSeq != 1); @@ -1553,8 +1552,7 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha float const qScale, OutputHead* __restrict__ const output, // [totalNbIntputTokens][nbQHeads] KVCacheList const cacheList, uint32_t const batchSize, - float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and V - // cache. Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. Vec* __restrict__ const cgaXBuf, // [totalNbInputTokens][maxNbSubSeq] uint32_t* __restrict__ const semaphores = nullptr, // [totalNbInputTokens] @@ -1657,8 +1655,7 @@ void launchMLA( KVCachePageIndex const* kvCachePageList, // device pointer. shape: // [batchSize][maxNbPagesPerSeq] (Layout 1) uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. - // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { #if IS_MLA @@ -1779,8 +1776,7 @@ void launchMLAFlashInfer( KVCachePageIndex const* kvCachePageList, // device pointer. shape: // [batchSize][maxNbPagesPerSeq] (Layout 1) uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. 
- // Used only for int8/fp8 KV cache. + float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { #if IS_MLA diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index bbe314b7e3..1ac25fcf91 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -22,9 +22,8 @@ using tvm::ffi::Optional; #if MLA_WRAPPER void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, - int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, - TensorView kvCacheScale, TensorView semaphores, TensorView scratch, - bool enable_pdl) { + int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, + TensorView semaphores, TensorView scratch, bool enable_pdl) { auto stream = get_stream(output.device()); // Extract strides from TensorView (in elements, not bytes) @@ -39,8 +38,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp reinterpret_cast(vCacheVLLM.data_ptr()), reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, - reinterpret_cast(kvCacheScale.data_ptr()), - reinterpret_cast(semaphores.data_ptr()), + kvCacheScale, reinterpret_cast(semaphores.data_ptr()), reinterpret_cast(scratch.data_ptr()), enable_pdl, kv_stride_page, kv_stride_token, kv_stride_head, stream); } @@ -53,7 +51,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK #endif TensorView q, Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, - TensorView seqLen, int64_t batchSize, TensorView kvCacheScale, + TensorView seqLen, int64_t batchSize, double kvCacheScale, #if SPEC_DEC int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif @@ -78,8 +76,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK reinterpret_cast(kCacheVLLM.data_ptr()), reinterpret_cast(vCacheVLLM.data_ptr()), reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, - reinterpret_cast(seqLen.data_ptr()), batchSize, - reinterpret_cast(kvCacheScale.data_ptr()), + reinterpret_cast(seqLen.data_ptr()), batchSize, kvCacheScale, #if SPEC_DEC qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), reinterpret_cast(mask.data_ptr()), diff --git a/flashinfer/decode.py b/flashinfer/decode.py index d418e5cd90..a85a1b846c 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2461,9 +2461,7 @@ def xqa_batch_decode_with_kv_cache( page_size, sinks=sinks_new, q_scale=q_scale_value, - kv_scale=torch.tensor( - [kv_scale_value], dtype=torch.float32, device=query.device - ), + kv_scale=kv_scale_value, sliding_win_size=window_left + 1 if window_left >= 0 else 0, kv_layout=kv_layout, sm_count=sm_count, diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index fba5045d74..fd75e34f87 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -67,7 +67,7 @@ def xqa( max_seq_len: int, seq_lens: torch.Tensor, batch_size: int, - kv_scale: torch.Tensor, + kv_scale: float, semaphores: torch.Tensor, workspace_buffer: torch.Tensor, enable_pdl: bool, @@ -111,7 +111,7 @@ def _fake_xqa( max_seq_len: int, seq_lens: torch.Tensor, batch_size: int, - kv_scale: torch.Tensor, + kv_scale: float, semaphores: torch.Tensor, workspace_buffer: torch.Tensor, ) -> None: @@ -135,7 
+135,7 @@ def xqa( page_size: int, sinks: Optional[torch.Tensor] = None, q_scale: float = 1.0, - kv_scale: Optional[torch.Tensor] = None, + kv_scale: float = 1.0, sliding_win_size: int = 0, kv_layout: str = "NHD", sm_count: Optional[int] = None, @@ -184,10 +184,8 @@ def xqa( If None, no attention sinks are used. q_scale : float, default=1.0 Scale factor for query tensor. - kv_scale : Optional[torch.Tensor], default=None - Scale factor for KV cache with shape ``[1]``. - Data type should be torch.float32. - If None, defaults to 1.0. + kv_scale : float, default=1.0 + Scale factor for KV cache. sliding_win_size : int, default=0 Sliding window size for attention. If 0, no sliding window is used. kv_layout : str, default="NHD" @@ -214,9 +212,6 @@ def xqa( if sm_count is None: sm_count = get_device_sm_count(q.device) - if kv_scale is None: - kv_scale = torch.ones(1, dtype=torch.float32, device=q.device) - enable_pdl = enable_pdl if enable_pdl is not None else device_support_pdl(q.device) # Infer parameters from tensors @@ -316,7 +311,7 @@ def xqa_mla( max_seq_len: int, seq_lens: torch.Tensor, batch_size: int, - kv_scale: torch.Tensor, + kv_scale: float, semaphores: torch.Tensor, workspace_buffer: torch.Tensor, enable_pdl: bool, @@ -352,7 +347,7 @@ def _fake_xqa_mla( max_seq_len: int, seq_lens: torch.Tensor, batch_size: int, - kv_scale: torch.Tensor, + kv_scale: float, semaphores: torch.Tensor, workspace_buffer: torch.Tensor, enable_pdl: bool, @@ -375,7 +370,7 @@ def xqa_mla( semaphores: torch.Tensor, page_size: int, q_scale: float = 1.0, - kv_scale: Optional[torch.Tensor] = None, + kv_scale: float = 1.0, sm_count: Optional[int] = None, enable_pdl: Optional[bool] = None, ) -> None: @@ -412,10 +407,8 @@ def xqa_mla( Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128]. q_scale : float, default=1.0 Scale factor for query tensor. - kv_scale : Optional[torch.Tensor], default=None - Scale factor for KV cache with shape ``[1]``. - Data type should be torch.float32. - If None, defaults to 1.0. + kv_scale : float, default=1.0 + Scale factor for KV cache. sm_count : Optional[int], default=None Number of streaming multiprocessors to use. If None, will be inferred from the device. 
@@ -435,9 +428,6 @@ def xqa_mla( if sm_count is None: sm_count = get_device_sm_count(q.device) - if kv_scale is None: - kv_scale = torch.ones(1, dtype=torch.float32, device=q.device) - enable_pdl = enable_pdl if enable_pdl is not None else device_support_pdl(q.device) # Infer parameters from tensors diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index 4e81f72bd5..5701bdc1b8 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -29,7 +29,6 @@ def div_up(a, b): sm_count = props.multi_processor_count beam_width = 1 -q_scale = 1.0 class CacheSeq: @@ -181,6 +180,8 @@ def ref_attention( @pytest.mark.parametrize("valid_elems_per_head", [32, 128]) @pytest.mark.parametrize("head_grp_size", [8, 16]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +@pytest.mark.parametrize("kv_scale", [1.0, 0.5]) +@pytest.mark.parametrize("q_scale", [1.0, 0.5]) def test_xqa( batch_size, nb_k_heads, @@ -194,7 +195,11 @@ def test_xqa( use_sliding_window, enable_pdl, kv_layout, + kv_scale, + q_scale, ): + if kv_scale != 1.0 and fp8_kv_cache is False: + pytest.skip("kv cache scale works only for fp8 kv cache") set_random_seed(42) nb_q_heads = nb_k_heads * head_grp_size @@ -347,7 +352,7 @@ def cache_head_at( ) seq_len_list.fill_(seq_len) - kv_cache_scale = torch.ones(1, dtype=torch.float32, device="cuda") + kv_cache_scale = kv_scale nb_seq = nb_k_heads * batch_size nb_semaphores = round_up(nb_seq, 2) + 2 + nb_seq + 2 @@ -406,7 +411,7 @@ def cache_head_at( v_cache_seq=v_cache_seq, seq_len=seq_len, q_scale=q_scale, - kv_scale=kv_cache_scale[0], + kv_scale=kv_cache_scale, x_scale=1.0, attention_sinks=attention_sinks[idx_k_head, :] if use_attention_sinks @@ -443,6 +448,8 @@ def cache_head_at( get_compute_capability(torch.device(device="cuda"))[0] not in [12], reason="XQA mla is only supported on SM120 GPUs", ) +@pytest.mark.parametrize("kv_scale", [1.0, 0.5]) +@pytest.mark.parametrize("q_scale", [1.0, 0.5]) @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("seq_len", [2, 15, 256, 514, 2048]) @pytest.mark.parametrize("batch_size", [1, 2]) @@ -451,6 +458,8 @@ def test_xqa_mla( batch_size, seq_len, tokens_per_page, + kv_scale, + q_scale, enable_pdl, ): set_random_seed(42) @@ -570,7 +579,7 @@ def cache_head_at( ) seq_len_list.fill_(seq_len) - kv_cache_scale = torch.ones(1, dtype=torch.float32, device="cuda") + kv_cache_scale = kv_scale nb_seq = nb_k_heads * batch_size nb_semaphores = round_up(nb_seq, 2) + 2 + nb_seq + 2 @@ -623,7 +632,7 @@ def cache_head_at( v_cache_seq=v_cache_seq, seq_len=seq_len, q_scale=q_scale * math.sqrt(576), - kv_scale=kv_cache_scale[0], + kv_scale=kv_cache_scale, x_scale=1.0, attention_sinks=None, sliding_win_size=0, From 9721ff7ff11cd537ea5c3aba61aef0e037dddf74 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Wed, 5 Nov 2025 11:19:42 -0800 Subject: [PATCH 024/130] fix: support both pip and uv pip for finding flashinfer-python package (#2043) Update getJitIncludeDirs() to try pip first, then fallback to uv pip if pip is not available. This ensures compatibility with both standard pip and uv pip package managers when locating the flashinfer-python installation for JIT compilation include paths. The command now uses shell OR operator (||) to attempt pip first, and only falls back to uv pip if the first command fails. 
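For illustration, here is a minimal Python sketch of the fallback behavior described above. This is an assumption-level illustration for this document only, not the C++ helper the patch below modifies; the `location` parsing and variable names are hypothetical. Because the shell only evaluates the right-hand side of `||` when the left-hand command exits non-zero, the parsed output comes from whichever of `pip` or `uv pip` is actually installed:

```python
import subprocess

# Compound command: "uv pip show" runs only if plain "pip show" fails (e.g. pip not installed).
cmd = "pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null"
out = subprocess.run(cmd, shell=True, capture_output=True, text=True).stdout

# pip-style "show" output includes a "Location:" field; that is the site-packages directory
# used to build the JIT include paths.
location = next(
    (line.split(":", 1)[1].strip() for line in out.splitlines() if line.startswith("Location:")),
    None,  # neither tool found the package
)
print(location)
```

Doing the fallback inside a single shell invocation keeps the C++ side to a single command execution rather than probing each package manager separately.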
``` pytest -xs tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8_block_scaling ============================================================================================================================================================ test session starts ============================================================================================================================================================= platform linux -- Python 3.10.12, pytest-8.4.2, pluggy-1.6.0 rootdir: /home/scratch.dmoss_gpu_1/repos/flashinfer configfile: pytest.ini collected 1 item tests/moe/test_trtllm_cutlass_fused_moe.py [TensorRT-LLM][INFO] Compiling JIT runtime gemm_swapAB_256_128_128_16_128_2_82_8_1_GroupedWithOffset with options: [TensorRT-LLM][INFO] -std=c++17 [TensorRT-LLM][INFO] --gpu-architecture=sm_90a [TensorRT-LLM][INFO] --ptxas-options=-allow-expensive-optimizations=true [TensorRT-LLM][INFO] --ptxas-options=--register-usage-level=10 [TensorRT-LLM][INFO] --diag-suppress=161,174,177,940 [TensorRT-LLM][INFO] -D__FORCE_INCLUDE_CUDA_FP16_HPP_FROM_FP16_H__=1 [TensorRT-LLM][INFO] -D__FORCE_INCLUDE_CUDA_BF16_HPP_FROM_BF16_H__=1 [TensorRT-LLM][INFO] -O3 [TensorRT-LLM][INFO] -cubin [TensorRT-LLM][INFO] --expt-relaxed-constexpr [TensorRT-LLM][INFO] --expt-extended-lambda [TensorRT-LLM][INFO] --compiler-options=-fPIC,-O3,-Wno-deprecated-declarations,-Wno-abi [TensorRT-LLM][INFO] -I/home/scratch.dmoss_gpu_1/repos/flashinfer/flashinfer/data/csrc/nv_internal/tensorrt_llm [TensorRT-LLM][INFO] [TensorRT-LLM][INFO] Generated kernel code: #ifdef __CUDACC_RTC__ #ifndef NVRTC_JIT_COMPILATION #define NVRTC_JIT_COMPILATION #endif #include #else #include #include #endif #include #include #include #include using namespace deep_gemm; using SchedulerType = typename SchedulerSelectorSwapAB::type; __global__ void dummy_kernel() { void *ptr = (void *)&fp8_gemm_kernel_swapAB<256, 128, 128, 16, 128, 2, 8, 128, 128, 1, SchedulerType, GroupedWithOffsetSchedulerInputSwapAB>; } [TensorRT-LLM][INFO] NVCC compilation took 3064 ms [TensorRT-LLM][INFO] Compilation log: [TensorRT-LLM][INFO] Successfully copied kernel files to cache directory: /home/dmoss/.tensorrt_llm/cache/gemm_swapAB_256_128_128_16_128_2_82_8_1_GroupedWithOffset [TensorRT-LLM][INFO] Compiling JIT runtime gemm_swapAB_128_128_128_16_128_2_82_8_1_GroupedWithOffset with options: [TensorRT-LLM][INFO] -std=c++17 [TensorRT-LLM][INFO] --gpu-architecture=sm_90a [TensorRT-LLM][INFO] --ptxas-options=-allow-expensive-optimizations=true [TensorRT-LLM][INFO] --ptxas-options=--register-usage-level=10 [TensorRT-LLM][INFO] --diag-suppress=161,174,177,940 [TensorRT-LLM][INFO] -D__FORCE_INCLUDE_CUDA_FP16_HPP_FROM_FP16_H__=1 [TensorRT-LLM][INFO] -D__FORCE_INCLUDE_CUDA_BF16_HPP_FROM_BF16_H__=1 [TensorRT-LLM][INFO] -O3 [TensorRT-LLM][INFO] -cubin [TensorRT-LLM][INFO] --expt-relaxed-constexpr [TensorRT-LLM][INFO] --expt-extended-lambda [TensorRT-LLM][INFO] --compiler-options=-fPIC,-O3,-Wno-deprecated-declarations,-Wno-abi [TensorRT-LLM][INFO] -I/home/scratch.dmoss_gpu_1/repos/flashinfer/flashinfer/data/csrc/nv_internal/tensorrt_llm [TensorRT-LLM][INFO] [TensorRT-LLM][INFO] Generated kernel code: #ifdef __CUDACC_RTC__ #ifndef NVRTC_JIT_COMPILATION #define NVRTC_JIT_COMPILATION #endif #include #else #include #include #endif #include #include #include #include using namespace deep_gemm; using SchedulerType = typename SchedulerSelectorSwapAB::type; __global__ void dummy_kernel() { void *ptr = (void *)&fp8_gemm_kernel_swapAB<128, 128, 128, 16, 128, 2, 8, 128, 128, 1, SchedulerType, 
GroupedWithOffsetSchedulerInputSwapAB>; } [TensorRT-LLM][INFO] NVCC compilation took 1479 ms [TensorRT-LLM][INFO] Compilation log: [TensorRT-LLM][INFO] Successfully copied kernel files to cache directory: /home/dmoss/.tensorrt_llm/cache/gemm_swapAB_128_128_128_16_128_2_82_8_1_GroupedWithOffset . ============================================================================================================================================================= 1 passed in 9.02s ============================================================================================================================================================== ``` ## Summary by CodeRabbit * **Bug Fixes** * Improved package detection compatibility for alternative package management tool installations. --- csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh index 9222bf19d2..25ca90927d 100644 --- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh @@ -124,8 +124,9 @@ std::string getNvccCompiler() { std::vector getJitIncludeDirs() { static std::vector includeDirs; if (includeDirs.empty()) { - // Command to execute - char const* cmd = "pip show flashinfer-python 2>/dev/null"; + // Command to execute - try pip first, fallback to uv pip + char const* cmd = + "pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null"; // Buffer to store the output std::array buffer; From 747b4e286aa3075f4580cca5dc7dc811acc30614 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Wed, 5 Nov 2025 16:08:13 -0800 Subject: [PATCH 025/130] test: Fix test_sampling.py on Spark (#2042) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Current PR fixes `test_sampling.py::test_softmax` on Spark by inserting a `torch.cuda.synchronize()` before calling the softmax function. tl; dr why it works: PDL is enabled in these tests. Investigation shows that when PDL is enabled, `logits.view(-1).index_fill_(0, inf_idx, float("-inf"))` that prepares the inputs overlaps with the `probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr)` function itself. Hence, we need to ensure that the input preparation is complete before running the softmax function to get the correct output. #### Observations `test_sampling.py::test_softmax` fails on select cases Spark. Example output ``` # pytest tests/utils/test_sampling.py::test_softmax =================================================================================================================================================== test session starts =================================================================================================================================================== platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /flashinfer configfile: pytest.ini collected 324 items ... 
================================================================================================================================================= short test summary info ================================================================================================================================================= FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution(std=1)-128256-989] - AssertionError: assert False FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution(std=5)-128256-989] - AssertionError: assert False FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-gumbel_distribution(beta=0.1)-128256-989] - AssertionError: assert False ======================================================================================================================================== 3 failed, 321 passed, 1 warning in 10.33s ``` Observations from debugging: * When outputs are printed, rows containing all `nan`s are produced in the output of `probs = flashinfer.sampling.softmax(logits)` * Surprisingly, the test passes with `CUDA_LAUNCH_BLOCKING=1 pytest tests/utils/test_sampling.py::test_softmax` * `compute-sanitizer` does not detect any IMAs * Running only a failed test results in a pass: ``` $ pytest tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution\(std=1\)-128256-989] ... 1 passed, 1 warning in 0.80s ``` Towards a fix: * I empirically find that the test passes: * when the reference `torch.softmax()` is called before `flashinfer.sampling.softmax()` (currently the reference is called after) * when pdl is disabled in [line 67](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_sampling.py#L67) with `probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr, enable_pdl=False)` * when `torch.cuda.synchronize()` is inserted at line 64 as in this PR. ``` if neg_inf_input: # assign random logits to -inf num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item() inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf] logits.view(-1).index_fill_(0, inf_idx, float("-inf")) torch.cuda.synchronize() ## This fixes the issue for some reason! if temperature_arr: temperature_arr = torch.full((batch_size,), temperature, device="cuda:0") probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr) logits_scaled = logits / temperature_arr.unsqueeze(-1) ``` but **does not fix the issue if I place the synchronization any earlier**. An nsys profile shows that, surprisingly, the `logits.view(-1).index_fill_(0, inf_idx, float("-inf"))` and `flashinfer.sampling.softmax(logits, temperature=temperature_arr)` calls can overlap execution when pdl is enabled. Screenshot 2025-11-04 at 5 49 50 PM This means that the softmax kernel is launching before inputs are done being prepared when `neg_inf_input=True`. Hence, placing a `torch.cuda.synchronize()` after the fill or disabling pdl can solve the issue. With the current PR, the nsys timeline changes to: Screenshot 2025-11-04 at 5 51 32 PM and the unit test passes. ## 🔍 Related Issues ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit ## Release Notes * **Bug Fixes** * Improved synchronization of concurrent operations to ensure proper execution order and prevent potential timing-related issues. --- tests/utils/test_sampling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py index 20df72b55d..9e72c4f49b 100644 --- a/tests/utils/test_sampling.py +++ b/tests/utils/test_sampling.py @@ -61,6 +61,7 @@ def test_softmax( num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item() inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf] logits.view(-1).index_fill_(0, inf_idx, float("-inf")) + torch.cuda.synchronize() # wait for the index_fill_ to finish because it can overlap with the softmax kernel if temperature_arr: temperature_arr = torch.full((batch_size,), temperature, device="cuda:0") From adb0e89fdee0a3140a43982bc3bef4e79ce20046 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 5 Nov 2025 20:46:15 -0800 Subject: [PATCH 026/130] Fix dtype of output scales from mnnvl_moe_alltoallv_prepare_without_allgather (#2048) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## 📌 Description During https://github.com/flashinfer-ai/flashinfer/pull/1641 the dtype of output scales in moePrepare(mnnvl_moe_alltoallv_prepare_without_allgather) was accidentally changed from float to int32. This PR fixes that. ## 🔍 Related Issues Fix https://github.com/flashinfer-ai/flashinfer/issues/2040 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Corrected tensor type validation for mixture-of-experts scale preparation so scales are validated and handled as float32, preventing type mismatches with downstream float operations. * Ensured scale tensors are created on the same device as expert identifiers, keeping tensor placement consistent across distributed processing and avoiding cross-device issues.
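As a minimal, self-contained illustration of why the buffer dtype matters here (a toy sketch, not the actual `moe_prepare` API; the tensor names are hypothetical): copying float routing scales into an int32 buffer silently truncates them to zero, which is the failure mode the fix above removes by allocating the prepared scales with `dtype=torch.float32`.

```python
import torch

scales = torch.tensor([[0.25, 0.75]])  # per-token top-k routing scales (float32)

wrong = torch.empty_like(scales, dtype=torch.int32)
wrong.copy_(scales)   # dtype conversion truncates: values become 0

right = torch.empty_like(scales, dtype=torch.float32)
right.copy_(scales)   # values preserved

print(wrong)  # tensor([[0, 0]], dtype=torch.int32)
print(right)  # tensor([[0.2500, 0.7500]])
```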
--------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- csrc/trtllm_alltoall.cu | 2 +- flashinfer/comm/trtllm_alltoall.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/trtllm_alltoall.cu b/csrc/trtllm_alltoall.cu index 9ad4fff110..005fe10f1e 100644 --- a/csrc/trtllm_alltoall.cu +++ b/csrc/trtllm_alltoall.cu @@ -296,7 +296,7 @@ void moePrepareOp(TensorView expertsIds, Optional scales, CHECK_INPUT_TYPE(scales.value(), dl_float32); scalesPtr = static_cast(scales.value().data_ptr()); CHECK_DEVICE(preparedLocalScales.value(), expertsIds); - CHECK_INPUT_TYPE(preparedLocalScales.value(), dl_int32); + CHECK_INPUT_TYPE(preparedLocalScales.value(), dl_float32); TVM_FFI_ICHECK_EQ(preparedLocalScales.value().ndim(), 2); TVM_FFI_ICHECK_EQ(preparedLocalScales.value().size(0), maxTokenCountPerRank * epSize); TVM_FFI_ICHECK_EQ(preparedLocalScales.value().size(1), topK); diff --git a/flashinfer/comm/trtllm_alltoall.py b/flashinfer/comm/trtllm_alltoall.py index 114b85b30c..b6a57f41c2 100644 --- a/flashinfer/comm/trtllm_alltoall.py +++ b/flashinfer/comm/trtllm_alltoall.py @@ -236,7 +236,9 @@ def moe_prepare( ) if scales is not None: prepared_local_scales = torch.empty( - (max_token_count_per_rank * ep_size, top_k), **attrs + (max_token_count_per_rank * ep_size, top_k), + dtype=torch.float32, + device=attrs["device"], ) else: prepared_local_scales = None From b211926710c1ee8024b8383257b6b9ea71a5caa2 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 5 Nov 2025 22:06:30 -0800 Subject: [PATCH 027/130] Update trtllm-gen fused moe routing kernel and add more kernels (#1955) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description co-work with @IwakuraRein - update the trtllm-gen fused moe headers - add new kernels for trtllm-gen fused moe - for NvFp4, add tile 256 - for MxFp8 x MxFp4, add 128, 256 - for FP8 per-tensor, add 192, 256 - for FP8 block scale, add 128 - update the logics of `computeSelectedTileN` - add `tune_max_num_tokens` to FP8 per-tensor and FP8 block scale - rename `TLLM_GEN_BMM_CUBIN_PATH` to `TLLM_GEN_GEMM_CUBIN_PATH` - add `TLLM_GEN_EXPORT_FLASHINFER` **NOTE: split-k kernels are temporarily disabled as they cause failure in renormalize + expert 256 tests.** ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Expanded MoE tiling (adds 128/192/256), FP8 perโ€‘tensor MoE path, FP8/FP4 autotuner benchmark, and new tune_max_num_tokens tuning parameter. * **Improvements** * Router now supports tileโ€‘based (nonโ€‘powerโ€‘ofโ€‘two) layouts and propagates explicit valid M/N/K for safer sizing; autotuner logs include exception details; added export/compile flags and clearer kernel error messages. 
* **Bug Fixes** * Relaxed strict padding/powerโ€‘ofโ€‘two checks and made log2 handling safer. * **Tests** * Extended MoE tests to cover new FP8 blockโ€‘scale and routing scenarios. --------- Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: Siyuan Fu Co-authored-by: Siyuan Fu --- .../bench_trtllm_gen_fused_moe_autotuner.py | 142 ++++- csrc/trtllm_batched_gemm_runner.cu | 19 +- csrc/trtllm_fused_moe_kernel_launcher.cu | 42 +- csrc/trtllm_fused_moe_routing_deepseek.cu | 38 +- csrc/trtllm_fused_moe_routing_llama4.cu | 53 +- csrc/trtllm_fused_moe_routing_renormalize.cu | 38 +- csrc/trtllm_fused_moe_runner.cu | 8 +- flashinfer/artifacts.py | 6 +- flashinfer/autotuner.py | 2 +- flashinfer/fused_moe/core.py | 12 +- flashinfer/jit/fused_moe.py | 3 +- flashinfer/jit/gemm/core.py | 2 + .../trtllmGen_bmm_export/BatchedGemmEnums.h | 8 +- .../BatchedGemmInterface.h | 531 +++++++++--------- .../trtllmGen_bmm_export/BatchedGemmOptions.h | 174 +++--- .../GemmGatedActOptions.h | 51 +- .../trtllmGen_bmm_export/GemmOptions.h | 270 +++++++-- .../trtllmGen_bmm_export/KernelParams.h | 78 +-- .../trtllmGen_bmm_export/KernelParamsDecl.h | 48 -- .../trtllmGen_bmm_export/KernelTraits.h | 31 +- .../trtllmGen_bmm_export/TmaDescriptor.h | 4 +- .../trtllm/gen/CommonUtils.h | 2 - .../trtllm/gen/SfLayoutDecl.h | 2 - .../flashinfer/trtllm/fused_moe/DevKernel.h | 115 ++-- .../trtllm/fused_moe/RoutingKernel.cuh | 90 ++- .../trtllm/fused_moe/RoutingKernel.h | 22 +- tests/moe/test_trtllm_gen_fused_moe.py | 151 +++-- 27 files changed, 1235 insertions(+), 707 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 2a991829dd..e7e40e772f 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -8,13 +8,109 @@ fp4_quantize, mxfp8_quantize, ) -from flashinfer.fused_moe import trtllm_fp4_block_scale_moe +from flashinfer.fused_moe import ( + trtllm_fp4_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe, +) from flashinfer.autotuner import autotune from flashinfer.testing.utils import bench_gpu_time from flashinfer.utils import device_support_pdl +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +FLOAT4_E2M1_MAX = 6.0 + + +def fp8_quantize(x): + max = x.float().abs().nan_to_num().max() + scale = FLOAT8_E4M3_MAX / max + x = (x * scale).to(torch.float8_e4m3fn) + return x, 1.0 / scale -def bench_trtllm_gen_fused_moe_autotuner( + +def bench_trtllm_gen_fused_moe_autotuner_fp8( + tune_max_num_tokens: Optional[int], + quant_mode: Literal["Fp8-Per-Tensor"], + num_tokens: int, + num_experts: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + warmups: int, + iterations: int, +): + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + routing_logits = torch.rand(num_tokens, num_experts, device=device).to( + torch.bfloat16 + ) + hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( + torch.bfloat16 + ) + w13 = torch.randn( + num_experts, intermediate_size * 2, hidden_size, device=device + ).to(torch.bfloat16) + w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( + torch.bfloat16 + ) + + hidden_states, hidden_states_scale = fp8_quantize(hidden_states) + w13, w13_scale = fp8_quantize(w13) + w2, w2_scale = fp8_quantize(w2) + + output1_scale_scalar = torch.tensor( + [hidden_states_scale * w13_scale] * num_experts, device=device + ) + output1_scales_gate_scalar = torch.ones( + 
num_experts, device=device, dtype=torch.float32 + ) + output2_scale_scalar = torch.tensor( + [hidden_states_scale * w2_scale] * num_experts, device=device + ) + + fn = lambda: trtllm_fp8_per_tensor_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + w13, + output1_scale_scalar, + output1_scales_gate_scalar, + w2, + output2_scale_scalar, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + 1.0, # routed_scaling_factor + False, # use_routing_scales_on_input + None, + RoutingMethodType.TopK.value, + enable_pdl, + num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + ) + + def bench(do_autotune): + with autotune(do_autotune): + fn() + ms_list = bench_gpu_time( + fn, + dry_run_iters=warmups, + repeat_iters=iterations, + ) + median_ms = np.median(ms_list) + return median_ms + + ms = bench(do_autotune=False) + ms_tuned = bench(do_autotune=True) + print( + f"num tokens: {num_tokens}, num experts: {num_experts}, hidden size: {hidden_size}, intermediate size: {intermediate_size}, top k: {top_k}" + ) + print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms") + + +def bench_trtllm_gen_fused_moe_autotuner_fp4( tune_max_num_tokens: Optional[int], quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], num_tokens: int, @@ -143,12 +239,11 @@ def bench_trtllm_gen_fused_moe_autotuner( ) def bench(do_autotune): - # warmup with autotune(do_autotune): - for _ in range(warmups): - fn() + fn() ms_list = bench_gpu_time( fn, + dry_run_iters=warmups, repeat_iters=iterations, ) median_ms = np.median(ms_list) @@ -168,7 +263,7 @@ def bench(do_autotune): "--quant-mode", type=str, default="MxFP4xMxFP8", - choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], + choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16", "Fp8-Per-Tensor"], help="Quantization mode", ) parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens") @@ -193,14 +288,27 @@ def bench(do_autotune): "--iterations", type=int, default=100, help="Number of benchmark iterations" ) args = parser.parse_args() - bench_trtllm_gen_fused_moe_autotuner( - args.tune_max_num_tokens, - args.quant_mode, - args.num_tokens, - args.num_experts, - args.hidden_size, - args.intermediate_size, - args.top_k, - args.warmups, - args.iterations, - ) + if args.quant_mode == "Fp8-Per-Tensor": + bench_trtllm_gen_fused_moe_autotuner_fp8( + args.tune_max_num_tokens, + args.quant_mode, + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.warmups, + args.iterations, + ) + else: + bench_trtllm_gen_fused_moe_autotuner_fp4( + args.tune_max_num_tokens, + args.quant_mode, + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.warmups, + args.iterations, + ) diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index bf57fd5b9e..42fe8f7f59 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -144,6 +144,10 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes( gemmData.mProblemDimensions.mWorldSize = 1; gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + auto bmm = BatchedGemmInterface(); auto const configs = 
bmm.getBatchedGemmConfigs(); @@ -239,6 +243,10 @@ void TrtllmGenBatchedGemmRunner::run( int32_t multiProcessorCount; cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device); + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere bmm.runInitBeforeWorldSync(config, gemmData, static_cast(stream)); @@ -327,6 +335,10 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( gemmData.mProblemDimensions.mWorldSize = 1; gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) { auto const& optionsA = configs[idx0].mOptions; auto const& optionsB = configs[idx1].mOptions; @@ -387,8 +399,7 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( // Filter out invalid configs. std::vector validConfigIndices; for (auto const& configIndex : prioritizedIndices) { - auto const& config = configs[configIndex]; - auto isValidConfig = bmm.isValidConfig(config, gemmData); + auto isValidConfig = bmm.isValidConfig(configs[configIndex], gemmData); if (isValidConfig) { validConfigIndices.push_back(configIndex); } @@ -435,7 +446,9 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t auto const& config = configs[configIndex]; - return bmm.isValidConfig(config, gemmData); + // FIXME: temporarily disable split-k as renormalize routing plus expert number 256 failed in + // trtllm-gen ac83afb + return bmm.isValidConfig(config, gemmData) && config.mOptions.mClusterDimZ == 1; } } // namespace kernels diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 538dc92725..3fd9dab35e 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -63,13 +63,22 @@ std::set computeSelectedTileN(std::vector const& supported_til int64_t const num_tokens, int64_t const top_k, int64_t const num_local_experts) { float const avg_tokens_per_expert = static_cast(num_tokens * top_k) / num_local_experts; + // assume supported_tile_nums is sorted int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), supported_tile_nums.front(), supported_tile_nums.back()); - - std::set selected_tile_nums = { - std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim, - std::min(supported_tile_nums.back(), tile_tokens_dim * 2), - std::min(supported_tile_nums.back(), tile_tokens_dim * 4)}; + auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim); + + std::set selected_tile_nums; + selected_tile_nums.insert(tile_tokens_dim); + if (std::next(it) != supported_tile_nums.end()) { + selected_tile_nums.insert(*std::next(it)); + if (std::next(std::next(it)) != supported_tile_nums.end()) { + selected_tile_nums.insert(*std::next(std::next(it))); + } + } + if (it != supported_tile_nums.begin()) { + selected_tile_nums.insert(*std::prev(it)); + } return selected_tile_nums; } @@ -369,7 +378,7 @@ void trtllm_fp8_per_tensor_scale_moe( auto const hidden_size = 
hidden_states.size(1); bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8 - std::vector mSupportedTileN = {8, 16, 32, 64, 128}; + std::vector mSupportedTileN = {8, 16, 32, 64, 128, 192, 256}; std::set selected_tile_nums = computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); @@ -718,7 +727,7 @@ void trtllm_fp8_block_scale_moe( auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); - std::vector mSupportedTileN = {8, 16, 32, 64}; + std::vector mSupportedTileN = {8, 16, 32, 64, 128}; std::set selected_tile_nums = computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); @@ -1228,6 +1237,11 @@ Array trtllm_fp4_block_scale_moe( if (mDtypeAct != btg::Dtype::Bfloat16) { mSupportedTileN.push_back(128); } + if ((mDtypeAct == btg::Dtype::MxE4m3 && mDtypeWeights == btg::Dtype::MxE2m1) || + (mDtypeAct == btg::Dtype::E2m1 && mDtypeWeights == btg::Dtype::E2m1)) { + // MxFP4 x MxFP4 or NvFP4 x NvFP4 + mSupportedTileN.push_back(256); + } std::set selected_tile_nums = computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); // Build runners for all supported tile sizes @@ -1305,8 +1319,20 @@ Array> trtllm_get_valid_moe_configs( bool is_fp8_per_tensor = dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; - if (is_fp4_without_bf16_act || is_fp8_per_tensor) { + if (useDeepSeekFp8) { + supported_tile_nums.push_back(128); + } else if (is_fp8_per_tensor) { supported_tile_nums.push_back(128); + supported_tile_nums.push_back(192); + supported_tile_nums.push_back(256); + } else if (is_fp4_without_bf16_act) { + supported_tile_nums.push_back(128); + } + + if ((dtype_act == btg::Dtype::MxE4m3 && dtype_weights == btg::Dtype::MxE2m1) || + (dtype_act == btg::Dtype::E2m1 && dtype_weights == btg::Dtype::E2m1)) { + // MxFP4 x MxFP4 or NvFP4 x NvFP4 + supported_tile_nums.push_back(256); } std::set selected_tile_nums = computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 527924559d..7f9a664291 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -392,7 +392,14 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // Compute the runtime config for projections // Whether or not an expert is local is taken into account when smemExpertCount is computed // so we do not need to take it into account here. 
- const int32_t numCta = divUpLog2(count, params.mPaddingLog2); + + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } + int32_t ctaOffset; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); @@ -401,14 +408,31 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = - min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffset, params.mPaddingLog2) + count); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } // get the padded offset associated with this expert - const int32_t offset = mulLog2(ctaOffset, params.mPaddingLog2); - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } // write out padded count if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) { @@ -542,8 +566,6 @@ void runImpl(Data& data, void* stream) { } FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); - FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", - data.mPaddingLog2); int const numBlocks = data.mNumTokens; int const numThreadsHist = getMaxNumExperts(data.mNumExperts); diff --git a/csrc/trtllm_fused_moe_routing_llama4.cu b/csrc/trtllm_fused_moe_routing_llama4.cu index ebdd0b8720..13ca041644 100644 --- a/csrc/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/trtllm_fused_moe_routing_llama4.cu @@ -189,7 +189,13 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #pragma unroll for (int ii = 0; ii < ExpertsPerThread; ++ii) { auto count = getBits(expertCount, ii); - numCta += divUpLog2(count, params.mPaddingLog2); + int32_t num; + if constexpr (KernelParams::isPow2) { + num = divUpLog2(count, params.mPaddingLog2); + } else { + num = divUpTileN(count, params.mTileTokensDim); + } + numCta += num; } // second, we perform the exclusive sum across the warp int32_t ctaOffset; @@ -202,22 +208,39 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #pragma unroll for (int ii = 0; ii < ExpertsPerThread; ++ii) { auto count = getBits(expertCount, ii); - auto finalNumCta = divUpLog2(count, params.mPaddingLog2); + int32_t finalNumCta; + if constexpr (KernelParams::isPow2) { + finalNumCta = divUpLog2(count, params.mPaddingLog2); + } else { + finalNumCta = divUpTileN(count, params.mTileTokensDim); + } auto expertIdx = 
threadIdx.x * ExpertsPerThread + ii; // during the scan for expert offsets, we can already write out // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit` for (int cta = 0; cta < finalNumCta; ++cta) { params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] = - min(mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffsetExp, params.mPaddingLog2) + count); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffsetExp, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffsetExp + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffsetExp, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] = min(mnLimit1, mnLimit2); } ctaOffsetExp += finalNumCta; } // at this point, we can write out padded count from the warp-aggregate if (cute::elect_one_sync()) { - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -236,12 +259,20 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam // of registers auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; int32_t finalExpertOffset[ExpertsPerThread]; - finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); + if constexpr (KernelParams::isPow2) { + finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + finalExpertOffset[0] = mulTileN(ctaOffset, params.mTileTokensDim); + } #pragma unroll for (int ii = 1; ii < ExpertsPerThread; ++ii) { - finalExpertOffset[ii] = - finalExpertOffset[ii - 1] + - divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2); + int32_t tmp; + if constexpr (KernelParams::isPow2) { + tmp = divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2); + } else { + tmp = divUpMulTileN(getBits(expertCount, ii - 1), params.mTileTokensDim); + } + finalExpertOffset[ii] = finalExpertOffset[ii - 1] + tmp; } #pragma unroll @@ -455,8 +486,6 @@ void runImpl(Data const& data, void* stream) { NumExpertsLimit); FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); - FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", - data.mPaddingLog2); bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) || diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 1a4823d481..56939f8d02 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -165,14 +165,24 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } __syncthreads(); // Get the number of CTAs and the offset for each CTA - const int32_t numCta = divUpLog2(accExpertCount, params.mPaddingLog2); + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(accExpertCount, params.mPaddingLog2); + } else { + numCta = divUpTileN(accExpertCount, params.mTileTokensDim); + } int32_t ctaOffset = 0; int32_t 
numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); int32_t expertScanCounts = 0; - Scan(tempStorage) - .ExclusiveSum(divUpMulLog2(accExpertCount, params.mPaddingLog2), expertScanCounts); + int32_t tmpCount; + if constexpr (KernelParams::isPow2) { + tmpCount = divUpMulLog2(accExpertCount, params.mPaddingLog2); + } else { + tmpCount = divUpMulTileN(accExpertCount, params.mTileTokensDim); + } + Scan(tempStorage).ExclusiveSum(tmpCount, expertScanCounts); __syncthreads(); if (isLocalExpert) { @@ -180,15 +190,27 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) const int32_t localExpertIdx = (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = - min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + accExpertCount; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } } // at this point, we can write out padded count if (threadIdx.x == 0) { - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -399,8 +421,6 @@ void run(Data const& data, void* stream) { << NumExpertsLimit << "."; TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; - TVM_FFI_ICHECK_LE(data.mPaddingLog2, 8) - << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index a33843516e..21a2cad4b5 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -23,7 +23,6 @@ #include "flashinfer/trtllm/fused_moe/DevKernel.h" #include "flashinfer/trtllm/fused_moe/RoutingKernel.h" #include "flashinfer/trtllm/fused_moe/runner.h" -// #include namespace tensorrt_llm { namespace kernels { @@ -39,7 +38,9 @@ inline int32_t computeLog2(int32_t val, std::string const& name = "") { while (n >>= 1) { ++out; } - FLASHINFER_CHECK((1 << out) == val, "Expected ", name, " to be a power of 2, got ", val); + if ((1 << out) != val) { + out = -1; + } return out; } } // namespace @@ -90,6 +91,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumLimitedGroups = topkGroup; routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; @@ -124,6 +126,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumExperts = numExperts; 
routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; @@ -170,6 +173,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumExperts = numExperts; routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 25f679968f..733b7aed24 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -89,7 +89,7 @@ class ArtifactPath: TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23" + "23daeee32b60bde7947ce1ee7a58d4ab701f134b/batched_gemm-0d28130-add42d1" ) TRTLLM_GEN_GEMM: str = ( "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" @@ -105,7 +105,7 @@ class MetaInfoHash: "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a" ) TRTLLM_GEN_BMM: str = ( - "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152" + "6cfade1395f9648aba5dcf2c329114619e175c0f238882555178f98c8f5c1968" ) TRTLLM_GEN_GEMM: str = ( "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340" @@ -123,7 +123,7 @@ class CheckSumHash: "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f" ) TRTLLM_GEN_BMM: str = ( - "8df2aae8f3aa39d64d2c723e775640beb4ac602a6cbb02e497c2a7316e349934" + "46ccf0492e3ed10135c2861a4f4ef9bb45846610f9a9d2ccaf2d5bf01d2006fd" ) DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index f8af220916..a82fabd8c0 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -483,7 +483,7 @@ def choose_one( except Exception as e: shapes = self._get_input_sizes(tensors) logger.warning( - f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling." + f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling: {e}" ) # Log stacktrace as debug to not spam log diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index c91878ca0e..3ea148c780 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1676,9 +1676,10 @@ def trtllm_fp8_per_tensor_scale_moe( local_num_experts: int, routed_scaling_factor: Optional[float], use_routing_scales_on_input: bool, - tile_tokens_dim: int = 8, + tile_tokens_dim: Optional[int] = None, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ) -> torch.Tensor: """FP8 per tensor scale MoE operation. @@ -1700,9 +1701,10 @@ def trtllm_fp8_per_tensor_scale_moe( local_num_experts: Number of experts handled by this device routed_scaling_factor: Scaling factor for routing use_routing_scales_on_input: Whether to use routing scales on input - tile_tokens_dim: Tile dimension for tokens (default: 8) + tile_tokens_dim: Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). 
Auto-enabled for >= sm90. + tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] @@ -1733,6 +1735,7 @@ def trtllm_fp8_per_tensor_scale_moe( use_routing_scales_on_input, routing_method_type, enable_pdl, + tune_max_num_tokens, ) @@ -1758,6 +1761,7 @@ def trtllm_fp8_block_scale_moe( use_shuffled_weight: bool = False, weight_layout: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ) -> torch.Tensor: """FP8 block scale MoE operation. @@ -1778,9 +1782,10 @@ def trtllm_fp8_block_scale_moe( local_expert_offset: Offset of local experts in global expert space local_num_experts: Number of experts handled by this device routed_scaling_factor: Scaling factor for routing - tile_tokens_dim: Tile dimension for tokens (default: 8) + tile_tokens_dim: Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. + tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ @@ -1815,6 +1820,7 @@ def trtllm_fp8_block_scale_moe( use_shuffled_weight, weight_layout, enable_pdl, + tune_max_num_tokens, ) diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 11398fabd9..78c19e98ac 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -233,11 +233,12 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: ], extra_cuda_cflags=[ "-DTLLM_GEN_EXPORT_INTERFACE", + "-DTLLM_GEN_EXPORT_FLASHINFER", "-DTLLM_ENABLE_CUDA", "-DENABLE_BF16", "-DENABLE_FP8", "-DENABLE_FP4", - f'-DTLLM_GEN_BMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"', + f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"', ] + nvcc_flags, extra_include_paths=[ diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 6564aefa35..7873d0de14 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -381,6 +381,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec: ], extra_cuda_cflags=[ "-DTLLM_GEN_EXPORT_INTERFACE", + "-DTLLM_GEN_EXPORT_FLASHINFER", "-DTLLM_ENABLE_CUDA", f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"', ] @@ -531,6 +532,7 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec: ], extra_cuda_cflags=[ "-DTLLM_GEN_EXPORT_INTERFACE", + "-DTLLM_GEN_EXPORT_FLASHINFER", "-DTLLM_ENABLE_CUDA", f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"', ] diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h index 27955d2bdc..919d6cb00d 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h @@ -31,7 +31,9 @@ enum class RouteImpl { // Use LDGSTS to do the routing Ldgsts = 1, // Use UTMALDG.GATHER4 to do the routing - Tma = 2 + Tma = 2, + // Use LDG+STS to do the routing + LdgPlusSts = 3 }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,6 +50,10 @@ inline bool doesRouteImplUseTma(RouteImpl mode) { return (mode == RouteImpl::Tma //////////////////////////////////////////////////////////////////////////////////////////////////// +inline bool 
doesRouteImplUseLdgPlusSts(RouteImpl mode) { return (mode == RouteImpl::LdgPlusSts); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace batchedGemm //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h index 6b1f910178..f93f20d28e 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -24,18 +24,12 @@ #include "trtllm/gen/CudaKernelLauncher.h" #ifdef TLLM_GEN_EXPORT_INTERFACE +#ifdef TLLM_GEN_EXPORT_FLASHINFER #include "flashinferMetaInfo.h" -#endif // TLLM_GEN_EXPORT_INTERFACE - -#ifdef TLLM_GEN_BMM_CUBIN_PATH -static const std::string tllm_gen_bmm_cubin_path = std::string(TLLM_GEN_BMM_CUBIN_PATH); #else -static_assert(false, "TLLM_GEN_BMM_CUBIN_PATH macro is not defined when compiling"); -#endif - -namespace flashinfer::trtllm_cubin_loader { -std::string getCubin(const std::string& kernelName, const std::string& sha256); -} +#include "KernelMetaInfo.h" +#endif // TLLM_GEN_EXPORT_FLASHINFER +#endif // TLLM_GEN_EXPORT_INTERFACE namespace batchedGemm { @@ -79,13 +73,18 @@ struct BatchedGemmData { // The M dimension. // It is the total number of tokens if A is the activation matrix. // It is the total number of output channels if A is the weight matrix. + // ValidM/N/K by default assumes to be full range of M/N/K respectively. If we pad M/N/K due to + // alignment of other constraints, then we can specify ValidM/N/K to indicate the valid range. int32_t mM{0}; + int32_t mValidM{0}; // The N dimension. // It is the total number of tokens if B is the activation matrix. // It is the total number of output channels if B is the weight matrix. int32_t mN{0}; + int32_t mValidN{0}; // The K dimension. It is the hidden dimension of the input matrices. int32_t mK{0}; + int32_t mValidK{0}; // The rank id of the current device in the multi-gpu space. int32_t mRank{0}; // The number of devices in tensor-parallel group. @@ -457,28 +456,187 @@ class BatchedGemmInterface { public: using ModuleCache = std::unordered_map>; - BatchedGemmInterface() {} + ////////////////////////////////////////////////////////////////////////////////////////////////// + + BatchedGemmInterface(bool const exportsCubin = false, int32_t const numRotations = 1) + : mExportsCubin(exportsCubin), mNumRotations(numRotations) {} + + ////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifndef TLLM_GEN_EXPORT_INTERFACE + // Generates and compiles the kernel using either nvcc or nvrtc. + BatchedGemmConfig generateAndCompileKernel(BatchedGemmConfig const& batchedGemmConfig) const; +#endif + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Provided config must be validated with isValidConfig before the call. 
- int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, - void* cudaStream, int32_t multiProcessorCount, bool usePdl = true, - std::optional> moduleCache = std::nullopt); + int32_t run(BatchedGemmConfig const& config, void* workspace, + BatchedGemmData const& batchedGemmData, void* cudaStream, + int32_t /*multiProcessorCount*/, bool usePdl = true, + std::optional> moduleCache = std::nullopt) { + // Might be used. + (void)usePdl; + (void)moduleCache; + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, batchedGemmData); + + bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; + bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 && + options.mDtypeB == tg::Dtype::E4m3; + + auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData); + float* dPtrRowMax{nullptr}; + uint32_t* dPtrRowMaxBars{nullptr}; + + // Set the completion barriers to 0 if needed. + if (useDeepSeekFp8 && options.mFusedAct) { + dPtrRowMax = reinterpret_cast(alignPtr(reinterpret_cast(workspace), 1024)); + dPtrRowMaxBars = reinterpret_cast( + alignPtr(reinterpret_cast(dPtrRowMax) + workspaceSizes[0], 1024)); + auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1], + reinterpret_cast(cudaStream)); + if (err != cudaSuccess) { + return 1; + } + } + + auto [numCtaBatch, numCtaTile, numCtaInner] = + getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); + auto kernelParams = KernelParamsSetup::setKernelParams( + options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB, + batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA, + batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA, + batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias, + batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC, + batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit, + batchedGemmData.mInputBuffers.mPtrGatedActAlpha, + batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, + dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, + batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch); + + // The size of the grid. + std::vector grid = batchM ? std::vector{numCtaBatch, numCtaTile, numCtaInner} + : std::vector{numCtaTile, numCtaBatch, numCtaInner}; + + BatchedGemmConfig batchedGemmConfig = config; +#ifndef TLLM_GEN_EXPORT_INTERFACE + // Generate and compile the kernel if data is not provided. 
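+    // In this non-export path the kernel is JIT-compiled here and launched through the
+    // CudaRunner; the cubin-loading branch further below is never reached because of the
+    // early return at the end of this block.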
+ if (config.mData == nullptr) { + batchedGemmConfig = generateAndCompileKernel(batchedGemmConfig); + } + TLLM_CHECK_ERROR(batchedGemmConfig.mCudaRunner != nullptr, "CudaRunner is not set"); + batchedGemmConfig.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid, + /* cluster */ {}, + /* instanceId */ batchedGemmConfig.mInstanceIdx); + return 0; +#endif + + CUmodule cuModule; + CUfunction cuFunction; + + if (moduleCache.has_value()) { + ModuleCache& moduleCacheRef = moduleCache.value().get(); + + // Modules are associated with a specific context, so the context is included in the key + CUcontext ctx; + unsigned long long ctxId; + cuCtxGetCurrent(&ctx); + cuCtxGetId(ctx, &ctxId); + + // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a + // string in decimal representation. + std::string const ctxName = + std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); + std::string const funcName = std::string(batchedGemmConfig.mFunctionName); + auto const moduleKey = ctxName + funcName; + auto module = moduleCacheRef.find(moduleKey); + + // Use cache if module is found, otherwise load and insert into cache + if (module != moduleCacheRef.end()) { + cuFunction = std::get<1>(module->second); + } else { + gemm::loadCubinData(&cuModule, batchedGemmConfig); + cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName); + moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); + } + } else { + gemm::loadCubinData(&cuModule, batchedGemmConfig); + cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName); + } + + // Prepare the grid/block. + dim3 block3{static_cast(batchedGemmConfig.mNumThreadsPerCTA), + static_cast(1), static_cast(1)}; + dim3 grid3{(grid.size() > 0 ? static_cast(grid[0]) : 1u), + (grid.size() > 1 ? static_cast(grid[1]) : 1u), + (grid.size() > 2 ? static_cast(grid[2]) : 1u)}; + // Prepare the cluster size. + dim3 cluster3{static_cast(options.mClusterDimX), + static_cast(options.mClusterDimY), + static_cast(options.mClusterDimZ)}; + + // Run the kernel. + auto result = trtllm::gen::launchKernel( + (void*)&kernelParams, cudaStream, batchedGemmConfig.mSharedMemSize, cuFunction, block3, + grid3, cluster3, + usePdl && (batchedGemmConfig.mOptions.mGridWaitForPrimaryEarlyExit | + batchedGemmConfig.mOptions.mGridWaitForPrimaryA | + batchedGemmConfig.mOptions.mGridWaitForPrimaryB)); + if (result != CUDA_SUCCESS) { + return -1; + } + // If a module cache has not been given, unload the module to avoid leaking + if (!moduleCache.has_value()) { + cuModuleUnload(cuModule); + } + return 0; + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Initializes the buffers before the world sync. Must be called before run. int32_t runInitBeforeWorldSync(BatchedGemmConfig const& /* config */, BatchedGemmData const& /* data */, void* /* cudaStream */) const { return 0; - }; + } - size_t getWorkspaceSizeInBytes(BatchedGemmConfig const& /* config */, - BatchedGemmData const& /* data */) const; + ////////////////////////////////////////////////////////////////////////////////////////////////// + + size_t getWorkspaceSizeInBytes(BatchedGemmConfig const& config, + BatchedGemmData const& data) const { + auto workspaceSizes = getWorkspaceSizesInBytes(config, data); + auto size = std::accumulate(workspaceSizes.begin(), workspaceSizes.end(), 0); + // Additional 1023 bytes to align the pointer to 1024 + return size > 0 ? 
size + 1023 : 0; + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Returns the list of all available cubin configurations - BatchedGemmConfig const* getBatchedGemmConfigs() const; + BatchedGemmConfig const* getBatchedGemmConfigs() const { +#ifdef TLLM_GEN_EXPORT_INTERFACE + return tensorrt_llm::kernels::tllmGenBatchedGemmList; +#else + return nullptr; +#endif + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Returns the number of available cubin configurations - size_t getNumBatchedGemmConfigs() const; + size_t getNumBatchedGemmConfigs() const { +#ifdef TLLM_GEN_EXPORT_INTERFACE + return tensorrt_llm::kernels::tllmGenBatchedGemmListLen; +#else + return 0; +#endif + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Returns the grid dimensions of the current kernel. std::tuple getGridDim( @@ -523,6 +681,8 @@ class BatchedGemmInterface { return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner); } + ////////////////////////////////////////////////////////////////////////////////////////////////// + // Returns the number of CTAs of the current kernel. int32_t getNumCtas(BatchedGemmOptions const& options, std::optional maxNumCtasInBatchDim = std::nullopt) const { @@ -530,278 +690,117 @@ class BatchedGemmInterface { return numCtasBatch * numCtasTile * numCtasInner; } - // Returns true if the configuration of the cubin can be executed for the given params. - bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const; + ////////////////////////////////////////////////////////////////////////////////////////////////// // Creates GemmOptions from kernel and data. 
BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, - BatchedGemmData const& data) const; - - private: - // Aligns the pointer to the alignment - template - inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const; - - // Returns the size of the workspace buffers in bytes - std::vector getWorkspaceSizesInBytes(BatchedGemmConfig const& config, - BatchedGemmData const& data) const; - - // Returns the size padded to the alignment - size_t getSizePaddedToAlignment(size_t size, size_t alignment) const; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline Dtype* BatchedGemmInterface::alignPtr(Dtype* ptr, int64_t alignment) const { - assert((alignment & (alignment - 1)) == 0 && "Alignment must be a power of 2"); - return reinterpret_cast((reinterpret_cast(ptr) + alignment - 1) & - ~(alignment - 1)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -BatchedGemmConfig const* BatchedGemmInterface::getBatchedGemmConfigs() const { -#ifdef TLLM_GEN_EXPORT_INTERFACE - return tensorrt_llm::kernels::tllmGenBatchedGemmList; -#else - return nullptr; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const { -#ifdef TLLM_GEN_EXPORT_INTERFACE - return sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList) / - sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList[0]); -#else - return 0; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -BatchedGemmOptions BatchedGemmInterface::getOptionsFromConfigAndData( - BatchedGemmConfig const& config, BatchedGemmData const& data) const { - // Create options from config and data. - BatchedGemmOptions options; - options = config.mOptions; - options.mM = data.mProblemDimensions.mM; - options.mN = data.mProblemDimensions.mN; - options.mK = data.mProblemDimensions.mK; - options.mBatchedM = data.mProblemDimensions.mBatchedM; - options.mBatchedN = data.mProblemDimensions.mBatchedN; - options.mBatchMode = data.mProblemDimensions.mBatchM ? BatchedGemmOptions::BatchMode::BatchM - : BatchedGemmOptions::BatchMode::BatchN; - options.mNumBatches = data.mProblemDimensions.mNumBatches; - options.mNumTokens = data.mProblemDimensions.mNumTokens; - return options; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -bool BatchedGemmInterface::isValidConfig(BatchedGemmConfig const& config, - BatchedGemmData const& data) const { - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, data); - - // Is Blackwell? - bool isBlackwell = gemm::isSmVersionBlackwell(config.mSm); + BatchedGemmData const& data) const { + BatchedGemmOptions options; + options = config.mOptions; + options.mM = data.mProblemDimensions.mM; + options.mN = data.mProblemDimensions.mN; + options.mK = data.mProblemDimensions.mK; + options.mValidM = data.mProblemDimensions.mValidM; + options.mValidN = data.mProblemDimensions.mValidN; + options.mValidK = data.mProblemDimensions.mValidK; + options.mBatchedM = data.mProblemDimensions.mBatchedM; + options.mBatchedN = data.mProblemDimensions.mBatchedN; + options.mBatchMode = data.mProblemDimensions.mBatchM ? 
BatchedGemmOptions::BatchMode::BatchM + : BatchedGemmOptions::BatchMode::BatchN; + options.mNumBatches = data.mProblemDimensions.mNumBatches; + options.mNumTokens = data.mProblemDimensions.mNumTokens; + return options; + } - // Check options without modifications. - return checkAndUpdateBatchedGemmOptions(options, isBlackwell, - /* updateOptions */ false); -} + ////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Returns true if the configuration of the cubin can be executed for the given params. + bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const { + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, data); -size_t BatchedGemmInterface::getSizePaddedToAlignment(size_t size, size_t alignment) const { - assert((alignment & (alignment - 1)) == 0); - return (size + alignment - 1) & ~(alignment - 1); -} + // Is Blackwell? + bool isBlackwell = gemm::isSmVersionBlackwell(config.mSm); -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Check options without modifications. + return checkAndUpdateBatchedGemmOptions(options, isBlackwell, + /* updateOptions */ false); + } -size_t BatchedGemmInterface::getWorkspaceSizeInBytes(BatchedGemmConfig const& config, - BatchedGemmData const& data) const { - auto workspaceSizes = getWorkspaceSizesInBytes(config, data); - auto size = std::accumulate(workspaceSizes.begin(), workspaceSizes.end(), 0); - // Additional 1023 bytes to align the pointer to 1024 - return size > 0 ? size + 1023 : 0; -} + ////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// + private: + ////////////////////////////////////////////////////////////////////////////////////////////////// -std::vector BatchedGemmInterface::getWorkspaceSizesInBytes( - BatchedGemmConfig const& config, BatchedGemmData const& data) const { - std::vector workspaceSizes; + template + inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const { + assert((alignment & (alignment - 1)) == 0 && "Alignment must be a power of 2"); + return reinterpret_cast((reinterpret_cast(ptr) + alignment - 1) & + ~(alignment - 1)); + } - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, data); + ////////////////////////////////////////////////////////////////////////////////////////////////// - if (options.mUseDeepSeekFp8 && options.mFusedAct) { - int32_t totalNumPaddedTokens = 0; - auto const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; - if (!options.mEnablesEarlyExit || options.mNumTokens == 0) { - for (int32_t bi = 0; bi < options.mNumBatches; ++bi) { - totalNumPaddedTokens += batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM) - : gemm::divUpMul(options.mBatchedN[bi], options.mTileN); + // Returns the size of the workspace buffers in bytes + std::vector getWorkspaceSizesInBytes(BatchedGemmConfig const& config, + BatchedGemmData const& data) const { + std::vector workspaceSizes; + + // Get options from config and data. 
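+    // Note: workspace buffers are only needed for DeepSeek FP8 with a fused activation
+    // (one buffer for the row-max values, one for their completion barriers); for all
+    // other configurations this returns an empty list.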
+ auto options = getOptionsFromConfigAndData(config, data); + + if (options.mUseDeepSeekFp8 && options.mFusedAct) { + int32_t totalNumPaddedTokens = 0; + auto const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; + if (!options.mEnablesEarlyExit || options.mNumTokens == 0) { + for (int32_t bi = 0; bi < options.mNumBatches; ++bi) { + totalNumPaddedTokens += batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM) + : gemm::divUpMul(options.mBatchedN[bi], options.mTileN); + } + } else { + // Get tile in token dim. + auto tileTokensDim = batchM ? options.mTileM : options.mTileN; + totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim; } - } else { - // Get tile in token dim. - auto tileTokensDim = batchM ? options.mTileM : options.mTileN; - totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim; - } - - // Get options from config. - auto& options = config.mOptions; - int const tokenTile = batchM ? options.mTileM : options.mTileN; + // Get options from config. + auto& options = config.mOptions; - auto const numTokens = totalNumPaddedTokens; - auto const intermediateDim = batchM ? options.mN : options.mM; - auto const intermediateTile = batchM ? options.mTileN : options.mTileM; + int const tokenTile = batchM ? options.mTileM : options.mTileN; - auto const numBytesRowMax = intermediateDim * totalNumPaddedTokens / 128 * sizeof(float); + auto const numTokens = totalNumPaddedTokens; + auto const intermediateDim = batchM ? options.mN : options.mM; + auto const intermediateTile = batchM ? options.mTileN : options.mTileM; - auto const numTilesToken = numTokens / tokenTile; - auto const numTilesInt = intermediateDim / intermediateTile; - auto const numBytesRowMaxBars = numTilesToken * numTilesInt / 2 * sizeof(uint32_t); - - // TODO: do we need to pad to 1024? - workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMax, 1024)); - workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMaxBars, 1024)); - } + auto const numBytesRowMax = intermediateDim * totalNumPaddedTokens / 128 * sizeof(float); - return workspaceSizes; -} + auto const numTilesToken = numTokens / tokenTile; + auto const numTilesInt = intermediateDim / intermediateTile; + auto const numBytesRowMaxBars = numTilesToken * numTilesInt / 2 * sizeof(uint32_t); -//////////////////////////////////////////////////////////////////////////////////////////////////// -int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace, - BatchedGemmData const& batchedGemmData, void* cudaStream, - int32_t /* multiProcessorCount */, bool usePdl, - std::optional> moduleCache) { - // Might be used. - (void)usePdl; - (void)moduleCache; - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, batchedGemmData); - - bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; - bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 && - options.mDtypeB == tg::Dtype::E4m3; - - auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData); - float* dPtrRowMax{nullptr}; - uint32_t* dPtrRowMaxBars{nullptr}; - - // Set the completion barriers to 0 if needed. 
- if (useDeepSeekFp8 && options.mFusedAct) { - dPtrRowMax = reinterpret_cast(alignPtr(reinterpret_cast(workspace), 1024)); - dPtrRowMaxBars = reinterpret_cast( - alignPtr(reinterpret_cast(dPtrRowMax) + workspaceSizes[0], 1024)); - auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1], - reinterpret_cast(cudaStream)); - if (err != cudaSuccess) { - return 1; + // TODO: do we need to pad to 1024? + workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMax, 1024)); + workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMaxBars, 1024)); } - } - - auto [numCtaBatch, numCtaTile, numCtaInner] = - getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); - auto kernelParams = KernelParamsSetup::setKernelParams( - options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB, - batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA, - batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA, - batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias, - batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC, - batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit, - batchedGemmData.mInputBuffers.mPtrGatedActAlpha, - batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, - dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, - batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch); - - // The size of the grid. - std::vector grid = batchM ? std::vector{numCtaBatch, numCtaTile, numCtaInner} - : std::vector{numCtaTile, numCtaBatch, numCtaInner}; -#ifdef TLLM_GEN_EXPORT_INTERFACE - CUmodule cuModule; - CUfunction cuFunction; - - auto fiModuleLoadData = [&](CUmodule* module) { - const std::string sha256 = config.mHash ? config.mHash : ""; - std::string fname_cubin = config.mFunctionName; - if (!fname_cubin.empty()) { - fname_cubin[0] = static_cast(std::toupper(static_cast(fname_cubin[0]))); - } - fname_cubin = tllm_gen_bmm_cubin_path + "/" + fname_cubin + ".cubin"; - std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256); - cuModuleLoadData(&cuModule, cubin.c_str()); - }; - - if (moduleCache.has_value()) { - ModuleCache& moduleCacheRef = moduleCache.value().get(); - - // Modules are associated with a specific context, so the context is included in the key - CUcontext ctx; - unsigned long long ctxId; - cuCtxGetCurrent(&ctx); - cuCtxGetId(ctx, &ctxId); - - // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a - // string in decimal representation. 
- std::string const ctxName = - std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); - std::string const funcName = std::string(config.mFunctionName); - auto const moduleKey = ctxName + funcName; - auto module = moduleCacheRef.find(moduleKey); - - // Use cache if module is found, otherwise load and insert into cache - if (module != moduleCacheRef.end()) { - cuFunction = std::get<1>(module->second); - } else { - fiModuleLoadData(&cuModule); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); - moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); - } - } else { - fiModuleLoadData(&cuModule); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + return workspaceSizes; } - // Prepare the grid/block. - dim3 block3{static_cast(config.mNumThreadsPerCTA), static_cast(1), - static_cast(1)}; - dim3 grid3{(grid.size() > 0 ? static_cast(grid[0]) : 1u), - (grid.size() > 1 ? static_cast(grid[1]) : 1u), - (grid.size() > 2 ? static_cast(grid[2]) : 1u)}; - // Prepare the cluster size. - dim3 cluster3{static_cast(options.mClusterDimX), - static_cast(options.mClusterDimY), - static_cast(options.mClusterDimZ)}; - - // Run the kernel. - auto result = trtllm::gen::launchKernel( - (void*)&kernelParams, cudaStream, config.mSharedMemSize, cuFunction, block3, grid3, cluster3, - usePdl && (config.mOptions.mGridWaitForPrimaryEarlyExit | - config.mOptions.mGridWaitForPrimaryA | config.mOptions.mGridWaitForPrimaryB)); - if (result != CUDA_SUCCESS) { - return -1; - } - // If a module cache has not been given, unload the module to avoid leaking - if (!moduleCache.has_value()) { - cuModuleUnload(cuModule); + ////////////////////////////////////////////////////////////////////////////////////////////////// + + // Returns the size padded to the alignment + size_t getSizePaddedToAlignment(size_t size, size_t alignment) const { + assert((alignment & (alignment - 1)) == 0); + return (size + alignment - 1) & ~(alignment - 1); } -#else - config.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid); -#endif + ////////////////////////////////////////////////////////////////////////////////////////////////// - return 0; -} + private: + // Whether to export the cubin file. + bool mExportsCubin; + // The number of rotations. 
+ int32_t mNumRotations; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h index 07dcd30be4..6e53d00c17 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h @@ -55,6 +55,13 @@ namespace batchedGemm { +namespace trtllm { +namespace gen { +class CudaRunner; +class GenCfg; +} // namespace gen +} // namespace trtllm + namespace batchedGemm { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -80,42 +87,47 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, - int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, - bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, - bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, - gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, - int mmaM, int mmaN, bool mockAllReduce, int n, int numRegsCastAWarps, - int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp, - int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, - int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, + int epilogueTileN, bool fuseUtccpWithUtcmma, bool gridTriggerSecondaryA, + bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, + bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, + gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, + int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, + int numEpilogueWarps, int numRegsCastAWarps, int numRegsCopySfLdsSttm, + int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, + int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, - bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, - bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, - gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector batchedM, - std::vector batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch, - int numTokens, RouteImpl routeImpl, std::optional routeSfsImpl, - bool gridWaitForPrimaryRouting, bool fusedAct, bool useTmaOobOpt) + bool useMaxTmemOverlap, bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, + bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, 
+ int validM, int validN, int validK, int worldSize, + // GemmGatedActOptions + gemmGatedAct::ActType actType, bool clampBeforeAct, + // BatchedGemmOptions + std::vector batchedM, std::vector batchedN, BatchMode batchMode, bool fusedAct, + bool gridWaitForPrimaryRouting, bool isStaticBatch, int numBatches, int numRegsPerThreadLoadB, + int numRegsPerThreadLoadSfB, int numTokens, int numWarpsLoadB, int numWarpsLoadSfB, + RouteImpl routeImpl, std::optional routeSfsImpl, bool useTmaOobOpt) : gemmGatedAct::GemmGatedActOptions( gemm::GemmOptions( allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, - epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA, - gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, - gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, - layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, - numRegsCopySfLdsSttm, numSlicesForSplitK, numSlicesForSliceK, numStages, - numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, - numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, - sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler, - transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, - useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB, - useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps, - useUnrollLoop2xForMma, worldSize), + epilogueLdtmBits, epilogueTileM, epilogueTileN, fuseUtccpWithUtcmma, + gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, + gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, + k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, + numEpilogueWarps, numRegsCastAWarps, numRegsCopySfLdsSttm, + numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, + numSlicesForSliceK, numStages, numStagesMma, numStagesMmaWithinWorkTile, + numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp, + sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK, splitK, + tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule, + useDeepSeekFp8, useHoistTryWaitForCustomMmaSchedule, useMaxTmemOverlap, + usePerTokenSfA, usePerTokenSfB, useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, + useTwoMmaWarps, useUnrollLoop2xForMma, validM, validN, validK, worldSize), actType, clampBeforeAct), mBatchedM(batchedM), mBatchedN(batchedN), @@ -124,10 +136,11 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting), mIsStaticBatch(isStaticBatch), mNumBatches(numBatches), - mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), - mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), - mNumRegsCastAWarps(numRegsCastAWarps), + mNumRegsPerThreadLoadB{numRegsPerThreadLoadB}, + mNumRegsPerThreadLoadSfB{numRegsPerThreadLoadSfB}, mNumTokens(numTokens), + mNumWarpsLoadB{numWarpsLoadB}, + mNumWarpsLoadSfB{numWarpsLoadSfB}, mRouteImpl(routeImpl), mRouteSfsImpl(routeSfsImpl), mUseTmaOobOpt(useTmaOobOpt) {} @@ -147,14 +160,16 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { bool mIsStaticBatch{true}; // Number of Gemm batches. 
int mNumBatches; - // Number of registers per thread for non-epilogue warps - int mNumRegsPerThreadNonEpilogueWarp{0}; - // Number of registers per thread for epilogue warps - int mNumRegsPerThreadEpilogueWarp{0}; - // Number of registers for the cast A warps. - int mNumRegsCastAWarps{0}; + // Number of registers per thread for load B + int mNumRegsPerThreadLoadB{0}; + // Number of registers per thread for load SfB + int mNumRegsPerThreadLoadSfB{0}; // Total number of tokens. int mNumTokens{32}; + // Number of warps for load B + int mNumWarpsLoadB{0}; + // Number of warps for load SfB + int mNumWarpsLoadSfB{0}; // Whether load the input tokens and do routing. RouteImpl mRouteImpl{RouteImpl::NoRoute}; // Routing logic for scaling factors. If not specified, mRouteImpl is used. @@ -167,8 +182,8 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { //////////////////////////////////////////////////////////////////////////////////////////////////// // Check if the options are valid or not. -bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell, - bool updateOptions = true) { +inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell, + bool updateOptions = true) { bool isValid = true; if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps) { if (updateOptions) { @@ -222,19 +237,21 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw if (options.mUseDeepSeekFp8) { if (batchM) { // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mN % 128 == 0, - "GEMM-N must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mN); + TLLM_CHECK_ERROR( + options.mN % 128 == 0 && options.mValidN % 128 == 0, + "GEMM-N and validN must be a multiple of 128 when using DeepSeek Fp8. Found ", options.mN, + " and validN=", options.mValidN); } else { // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mM % 128 == 0, - "GEMM-N must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mN); + TLLM_CHECK_ERROR( + options.mM % 128 == 0 && options.mValidM % 128 == 0, + "GEMM-M and validM must be a multiple of 128 when using DeepSeek Fp8. Found ", options.mM, + " and validM=", options.mValidM); } // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mK % 128 == 0, - "GEMM-K must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mK); + TLLM_CHECK_ERROR(options.mK % 128 == 0 && options.mValidK % 128 == 0, + "GEMM-K and validK must be a multiple of 128 when using DeepSeek Fp8. 
Found ", + options.mK, " and validK=", options.mValidK); TLLM_CHECK_ERROR(options.mDtypeC != tg::Dtype::E2m1 && options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, @@ -243,8 +260,10 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() != options.mRouteImpl) { TLLM_CHECK_ERROR( - options.mRouteSfsImpl.value() == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma, - "RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma"); + (options.mRouteSfsImpl.value() == RouteImpl::Ldgsts || + options.mRouteSfsImpl.value() == RouteImpl::LdgPlusSts) && + options.mRouteImpl == RouteImpl::Tma, + "RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts, when RouteImpl is Tma"); } else if (!options.mRouteSfsImpl.has_value()) { if (updateOptions) { options.mRouteSfsImpl = options.mRouteImpl; @@ -253,6 +272,16 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw return false; } } + + TLLM_CHECK_ERROR(options.mRouteImpl != RouteImpl::LdgPlusSts, + "LdgPlusSts does not support routing the tokens"); + + if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() == RouteImpl::LdgPlusSts) { + TLLM_CHECK_ERROR(!batchM, "LdgPlusSts only supports batch N"); + TLLM_CHECK_ERROR(options.mTileK <= 512 && options.mTileK >= 128, + "LdgPlusSts only supports 128 <= tileK <= 512"); + } + if (batchM) { if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) { TLLM_CHECK_ERROR(doesRouteImplUseNoRoute(options.mRouteImpl), @@ -326,6 +355,7 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw "2CTA BatchedGemm does not support routing along M dimension. To support it, " "change the input routing data layout to be padded to clusterDimX size."); } + return isValid; } @@ -336,19 +366,18 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw //////////////////////////////////////////////////////////////////////////////////////////////////// struct BatchedGemmConfig { - // When TRT-LLM Gen is exported to the other frameworks, the TLLM_GEN_EXPORT_INTERFACE must be - // defined. In this case, the cubins will be loaded from the provided data and function name. - // Otherwise, the kernel will be loaded from the CudaRunner. 
-#ifdef TLLM_GEN_EXPORT_INTERFACE uint8_t const* mData{nullptr}; - uint32_t const mSize{0}; - uint32_t const mSharedMemSize{0}; + uint32_t mSize{0}; + uint32_t mSharedMemSize{0}; char const* mFunctionName{nullptr}; - uint32_t const mNumThreadsPerCTA{0}; + uint32_t mNumThreadsPerCTA{0}; char const* mHash{nullptr}; -#else + + std::string mGenCfgJsonStr{""}; + char const* mExecPath{nullptr}; trtllm::gen::CudaRunner* mCudaRunner{nullptr}; -#endif + trtllm::gen::GenCfg* mGenCfg{nullptr}; + int32_t mInstanceIdx{0}; BatchedGemmOptions mOptions; gemm::SmVersion mSm{gemm::SmVersion::Sm100a}; @@ -356,27 +385,32 @@ struct BatchedGemmConfig { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline std::string dumpOptions(BatchedGemmOptions const& options) { +inline std::string dumpOptions(BatchedGemmOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; - ss << gemmGatedAct::dumpOptions(options) << ", "; - ss << "mBatchedM={}," << std::endl; - ss << "mBatchedN={}," << std::endl; + ss << gemmGatedAct::dumpOptions(options, dumpRuntimeParams) << ", "; + if (dumpRuntimeParams) { + ss << "mBatchedM={}," << std::endl; + ss << "mBatchedN={}," << std::endl; + } ss << "mBatchMode=batchedGemm::BatchedGemmOptions::BatchMode(" << static_cast(options.mBatchMode) << ")," << std::endl; - ss << "mNumBatches=" << options.mNumBatches << "," << std::endl; + ss << "mFusedAct=" << options.mFusedAct << "," << std::endl; + ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl; ss << "mIsStaticBatch=" << options.mIsStaticBatch << "," << std::endl; - ss << "mNumTokens=" << options.mNumTokens << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mNumBatches=" << options.mNumBatches << "," << std::endl; + } + ss << "mNumRegsPerThreadLoadB=" << options.mNumRegsPerThreadLoadB << "," << std::endl; + ss << "mNumRegsPerThreadLoadSfB=" << options.mNumRegsPerThreadLoadSfB << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mNumTokens=" << options.mNumTokens << "," << std::endl; + } + ss << "mNumWarpsLoadB=" << options.mNumWarpsLoadB << "," << std::endl; + ss << "mNumWarpsLoadSfB=" << options.mNumWarpsLoadSfB << "," << std::endl; ss << "mRouteImpl=batchedGemm::RouteImpl(" << static_cast(options.mRouteImpl) << ")," << std::endl; ss << "mRouteSfsImpl={batchedGemm::RouteImpl(" << static_cast(options.mRouteSfsImpl.value()) << ")}," << std::endl; - ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl; - ss << "mFusedAct=" << options.mFusedAct << "," << std::endl; - ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," - << std::endl; - ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," - << std::endl; - ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl; return ss.str(); } diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h index 1086cd4fd5..559118916d 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h @@ -45,6 +45,13 @@ namespace batchedGemm { +namespace trtllm { +namespace gen { +class CudaRunner; +class GenCfg; +} // namespace gen +} // namespace trtllm + 
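+// Forward-declaring CudaRunner/GenCfg lets the Config structs below hold pointers to them
+// even when the full code-generator headers are not included (export builds).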
namespace gemmGatedAct { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -130,10 +137,6 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& "Unsupported output hidden tile size"); } - if (options.mUseDeepSeekFp8) { - TLLM_CHECK_ERROR(hiddenSize % 256 == 0, "Output hidden size must be a multiple of 256"); - } - if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { int const outHiddenSize = (options.mTransposeMmaOutput ? options.mM : options.mN) / 2; int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); @@ -148,6 +151,21 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& return false; } + auto const validHiddenSize = options.mTransposeMmaOutput ? options.mValidM : options.mValidN; + if (options.mUseDeepSeekFp8) { + TLLM_CHECK_ERROR(hiddenSize % 256 == 0 && validHiddenSize % 256 == 0, "Hidden size (", + hiddenSize, ") and valid hidden size (", validHiddenSize, + ") must be a multiple of 256"); + } + + // + if (options.mUseShuffledMatrixA) { + auto const shuffleBlockSize = gemm::getShuffleBlockSize(options.mEpilogueTileM); + TLLM_CHECK_ERROR( + hiddenSize % (2 * shuffleBlockSize) == 0 && validHiddenSize % (2 * shuffleBlockSize) == 0, + "M/validM must be a multiple of 2 * shuffle block size (", 2 * shuffleBlockSize, + ") when useShuffledMatrixA"); + } if (options.mNumSlicesForSplitK > 1) { TLLM_CHECK_ERROR(doesSplitKUseDsmem(options.mSplitK), "Split-k GMEM and GemmGatedAct are not supported yet."); @@ -163,11 +181,11 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& //////////////////////////////////////////////////////////////////////////////////////////////////// -inline std::string dumpOptions(GemmGatedActOptions const& options) { +inline std::string dumpOptions(GemmGatedActOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; - ss << gemm::dumpOptions(options) << ", "; - ss << "mActType=" << "gemmGatedAct::ActType(" << static_cast(options.mActType) << ")," - << std::endl; + ss << gemm::dumpOptions(options, dumpRuntimeParams) << ", "; + ss << "mActType=" + << "gemmGatedAct::ActType(" << static_cast(options.mActType) << ")," << std::endl; ss << "mClampBeforeAct=" << options.mClampBeforeAct << "" << std::endl; return ss.str(); } @@ -179,19 +197,18 @@ inline std::string dumpOptions(GemmGatedActOptions const& options) { //////////////////////////////////////////////////////////////////////////////////////////////////// struct GemmGatedActConfig { - // When TRT-LLM Gen is exported to the other frameworks, the TLLM_GEN_EXPORT_INTERFACE must be - // defined. In this case, the cubins will be loaded from the provided data and function name. - // Otherwise, the kernel will be loaded from the CudaRunner. 
-#ifdef TLLM_GEN_EXPORT_INTERFACE uint8_t const* mData{nullptr}; - uint32_t const mSize{0}; - uint32_t const mSharedMemSize{0}; + uint32_t mSize{0}; + uint32_t mSharedMemSize{0}; char const* mFunctionName{nullptr}; - uint32_t const mNumThreadsPerCTA{0}; + uint32_t mNumThreadsPerCTA{0}; char const* mHash{nullptr}; -#else + + std::string mGenCfgJsonStr{""}; + char const* mExecPath{nullptr}; trtllm::gen::CudaRunner* mCudaRunner{nullptr}; -#endif + trtllm::gen::GenCfg* mGenCfg{nullptr}; + int32_t mInstanceIdx{0}; GemmGatedActOptions mOptions{}; gemm::SmVersion mSm{gemm::SmVersion::Sm100a}; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h index fc3bd88101..af6432f7a0 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h @@ -30,7 +30,14 @@ #include "trtllm/gen/CudaRunner.h" #include "trtllm/gen/GenCtx.h" #else +#ifdef TLLM_GEN_EXPORT_FLASHINFER +#include +namespace flashinfer::trtllm_cubin_loader { +std::string getCubin(const std::string& kernelName, const std::string& sha256); +} +#endif // TLLM_GEN_EXPORT_FLASHINFER #include +namespace batchedGemm { template void printArgs(T arg) { @@ -72,7 +79,12 @@ void printArgs(T first, Args... args) { #endif // TLLM_GEN_EXPORT_INTERFACE -namespace batchedGemm { +namespace trtllm { +namespace gen { +class CudaRunner; +class GenCfg; +} // namespace gen +} // namespace trtllm namespace gemm { @@ -91,28 +103,29 @@ struct GemmOptions { #endif GemmOptions() = default; - GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, - int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, - bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, - bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, MatrixLayout layoutA, - MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, - bool mockAllReduce, int n, int numRegsCopySfLdsSttm, int numSlicesForSplitK, - int numSlicesForSliceK, int numStages, int numStagesMma, - int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, - bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, - int sfReshapeFactor, bool sliceK, SplitK splitK, int tileK, int tileM, int tileN, - TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, - bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, + int epilogueTileM, int epilogueTileN, bool fuseUtccpWithUtcmma, + bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, + bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, + bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, + MatrixLayout layoutA, MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, + int mmaM, int mmaN, bool mockAllReduce, int n, int numEpilogueWarps, + int numRegsCastAWarps, int numRegsCopySfLdsSttm, int 
numRegsPerThreadEpilogueWarp, + int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, + int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, + int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, + bool patchF2fp, std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, + tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int sfReshapeFactor, bool sliceK, + SplitK splitK, int tileK, int tileM, int tileN, TileScheduler tileScheduler, + bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8, + bool useHoistTryWaitForCustomMmaSchedule, bool useMaxTmemOverlap, bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, - bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, - int worldSize) + bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int validM, + int validN, int validK, int worldSize) : mAllReduceAlgo{allReduceAlgo}, mBiasType{biasType}, mBlockK(blockK), @@ -133,6 +146,7 @@ struct GemmOptions { mEpilogueLdtmBits{epilogueLdtmBits}, mEpilogueTileM{epilogueTileM}, mEpilogueTileN{epilogueTileN}, + mFuseUtccpWithUtcmma{fuseUtccpWithUtcmma}, mGridTriggerSecondaryA{gridTriggerSecondaryA}, mGridTriggerSecondaryB{gridTriggerSecondaryB}, mGridWaitForPrimaryEarlyExit{gridWaitForPrimaryEarlyExit}, @@ -151,7 +165,11 @@ struct GemmOptions { mMmaN{mmaN}, mMockAllReduce{mockAllReduce}, mN{n}, + mNumEpilogueWarps{numEpilogueWarps}, + mNumRegsCastAWarps(numRegsCastAWarps), mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm), + mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), + mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), mNumSlicesForSplitK{numSlicesForSplitK}, mNumSlicesForSliceK{numSlicesForSliceK}, mNumStages{numStages}, @@ -176,6 +194,7 @@ struct GemmOptions { mUseCustomMmaSchedule{useCustomMmaSchedule}, mUseDeepSeekFp8{useDeepSeekFp8}, mUseHoistTryWaitForCustomMmaSchedule{useHoistTryWaitForCustomMmaSchedule}, + mUseMaxTmemOverlap{useMaxTmemOverlap}, mUsePerTokenSfA{usePerTokenSfA}, mUsePerTokenSfB{usePerTokenSfB}, mUseShuffledMatrixA{useShuffledMatrixA}, @@ -183,6 +202,9 @@ struct GemmOptions { mUseTwoTmaLoadWarps{useTwoTmaLoadWarps}, mUseTwoMmaWarps{useTwoMmaWarps}, mUseUnrollLoop2xForMma{useUnrollLoop2xForMma}, + mValidM{validM}, + mValidN{validN}, + mValidK{validK}, mWorldSize{worldSize} {} // The all-reduce algorithm. @@ -233,6 +255,8 @@ struct GemmOptions { int mEpilogueTileM{128}; // Tile size for the epilogue in N dimension. int mEpilogueTileN{32}; + // Whether fuse UTCCP with UTC*MMA. + bool mFuseUtccpWithUtcmma{false}; // Whether load task A triggers the next grid. bool mGridTriggerSecondaryA{false}; // Whether load task B triggers the next grid. @@ -269,8 +293,16 @@ struct GemmOptions { bool mMockAllReduce{false}; // The N dimension of GEMM. int mN{64 * 4}; + // Number of Epilogue Warps + int mNumEpilogueWarps{4}; + // Number of registers for the cast A warps. + int mNumRegsCastAWarps{0}; // Number of registers for the LDS+STTM warps. int mNumRegsCopySfLdsSttm{0}; + // Number of registers per thread for epilogue warps + int mNumRegsPerThreadEpilogueWarp{0}; + // Number of registers per thread for non-epilogue warps + int mNumRegsPerThreadNonEpilogueWarp{0}; // Number of partitions along the K dimension. When mNumSlicesForSplitK > 1, // the problem is distributed across several SMs, where each CTA works on its local K slice. 
// Partial results are accumulated afterwards using either GMEM or DSMEM (in CGA) @@ -329,6 +361,8 @@ struct GemmOptions { // k-block. It benefits when the next k-block is already available and thus sustaining the // momentum, but it adds latency to the first k-block for smaller k-loop. bool mUseHoistTryWaitForCustomMmaSchedule{false}; + // Whether use the max Tmem overlap trick. + bool mUseMaxTmemOverlap{false}; // Apply per-token scales from A bool mUsePerTokenSfA{false}; // Apply per-token scales from B @@ -343,6 +377,15 @@ struct GemmOptions { bool mUseTwoMmaWarps{false}; // Whether to unroll the loop by 2x. bool mUseUnrollLoop2xForMma{true}; + // The valid range of M/N/K dimension of GEMM without padding values. + // Used to opportunistically remove memory traffic from the padding due to rigid SF shape + // constraint or TMA constraint. Such as: + // 1. outputDim % (4 * sfBlockSize) == 0; as 4x SFs are packed into 4 bytes + // 2. MxFp4 x Fp8 mmaType requires bespoke TMA load which requires hiddenDim % 128 == 0 + // 3. TMA requires 16B alignment for each row + int mValidM{-1}; + int mValidN{-1}; + int mValidK{-1}; // World size for all-reduce. int mWorldSize{1}; }; @@ -365,19 +408,17 @@ inline bool isSmVersionBlackwell(SmVersion smVersion) { //////////////////////////////////////////////////////////////////////////////////////////////////// struct GemmConfig { - // When TRT-LLM Gen is exported to the other frameworks, the TLLM_GEN_EXPORT_INTERFACE must be - // defined. In this case, the cubins will be loaded from the provided data and function name. - // Otherwise, the kernel will be loaded from the CudaRunner. -#ifdef TLLM_GEN_EXPORT_INTERFACE uint8_t const* mData{nullptr}; - uint32_t const mSize{0}; - uint32_t const mSharedMemSize{0}; + uint32_t mSize{0}; + uint32_t mSharedMemSize{0}; char const* mFunctionName{nullptr}; - uint32_t const mNumThreadsPerCTA{0}; + uint32_t mNumThreadsPerCTA{0}; char const* mHash{nullptr}; -#else + std::string mGenCfgJsonStr{""}; + char const* mExecPath{nullptr}; trtllm::gen::CudaRunner* mCudaRunner{nullptr}; -#endif + trtllm::gen::GenCfg* mGenCfg{nullptr}; + int32_t mInstanceIdx{0}; GemmOptions mOptions{}; SmVersion mSm{SmVersion::Sm100a}; @@ -407,7 +448,7 @@ inline std::string toString(trtllm::gen::MmaKind e) { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline std::string dumpOptions(GemmOptions const& options) { +inline std::string dumpOptions(GemmOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; ss << "mAllReduceAlgo=" << "gemm::AllReduceAlgo(" << static_cast(options.mAllReduceAlgo) << ")" @@ -447,6 +488,7 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mEpilogueLdtmBits=" << options.mEpilogueLdtmBits << "," << std::endl; ss << "mEpilogueTileM=" << options.mEpilogueTileM << "," << std::endl; ss << "mEpilogueTileN=" << options.mEpilogueTileN << "," << std::endl; + ss << "mFuseUtccpWithUtcmma=" << options.mFuseUtccpWithUtcmma << "," << std::endl; ss << "mGridTriggerSecondaryA=" << options.mGridTriggerSecondaryA << "," << std::endl; ss << "mGridTriggerSecondaryB=" << options.mGridTriggerSecondaryB << "," << std::endl; ss << "mGridWaitForPrimaryEarlyExit=" << options.mGridWaitForPrimaryEarlyExit << "," << std::endl; @@ -454,14 +496,18 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mGridWaitForPrimaryB=" << options.mGridWaitForPrimaryB << "," << std::endl; ss << "mHoistLoadTaskInit=" << options.mHoistLoadTaskInit << "," << 
std::endl; ss << "mHoistMmaTaskTryWaits=" << options.mHoistMmaTaskTryWaits << "," << std::endl; - ss << "mK=" << options.mK << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mK=" << options.mK << "," << std::endl; + } ss << "mKernelTraits={}" << "," << std::endl; ss << "mLayoutA=gemm::MatrixLayout(" << static_cast(options.mLayoutA) << ")" << "," << std::endl; ss << "mLayoutB=gemm::MatrixLayout(" << static_cast(options.mLayoutB) << ")" << "," << std::endl; - ss << "mM=" << options.mM << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mM=" << options.mM << "," << std::endl; + } ss << "mMmaK=" << options.mMmaK << "," << std::endl; ss << "mMmaKind=" << "trtllm::gen::MmaKind(" << static_cast(options.mMmaKind) << ")" @@ -469,8 +515,16 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mMmaM=" << options.mMmaM << "," << std::endl; ss << "mMmaN=" << options.mMmaN << "," << std::endl; ss << "mMockAllReduce=" << options.mMockAllReduce << "," << std::endl; - ss << "mN=" << options.mN << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mN=" << options.mN << "," << std::endl; + } + ss << "mNumEpilogueWarps=" << options.mNumEpilogueWarps << "," << std::endl; + ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; ss << "mNumRegsCopySfLdsSttm=" << options.mNumRegsCopySfLdsSttm << "," << std::endl; + ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," + << std::endl; + ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," + << std::endl; ss << "mNumSlicesForSplitK=" << options.mNumSlicesForSplitK << "," << std::endl; ss << "mNumSlicesForSliceK=" << options.mNumSlicesForSliceK << "," << std::endl; ss << "mNumStages=" << options.mNumStages << "," << std::endl; @@ -512,6 +566,7 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mUseDeepSeekFp8=" << options.mUseDeepSeekFp8 << "," << std::endl; ss << "mUseHoistTryWaitForCustomMmaSchedule=" << options.mUseHoistTryWaitForCustomMmaSchedule << "," << std::endl; + ss << "mUseMaxTmemOverlap=" << options.mUseMaxTmemOverlap << "," << std::endl; ss << "mUsePerTokenSfA=" << options.mUsePerTokenSfA << "," << std::endl; ss << "mUsePerTokenSfB=" << options.mUsePerTokenSfB << "," << std::endl; ss << "mUseShuffledMatrixA=" << options.mUseShuffledMatrixA << "," << std::endl; @@ -519,7 +574,12 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mUseTwoTmaLoadWarps=" << options.mUseTwoTmaLoadWarps << "," << std::endl; ss << "mUseTwoMmaWarps=" << options.mUseTwoMmaWarps << "," << std::endl; ss << "mUseUnrollLoop2xForMma=" << options.mUseUnrollLoop2xForMma << "," << std::endl; - ss << "mWorldSize=" << options.mWorldSize << std::endl; + if (dumpRuntimeParams) { + ss << "mValidM=" << options.mValidM << "," << std::endl; + ss << "mValidN=" << options.mValidN << "," << std::endl; + ss << "mValidK=" << options.mValidK << "," << std::endl; + ss << "mWorldSize=" << options.mWorldSize << std::endl; + } return ss.str(); } @@ -578,6 +638,51 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } + // If validM/N/K is not specified, then assume the full range of the dimension is valid. + if (options.mValidM == -1) { + options.mValidM = options.mM; + } + if (options.mValidN == -1) { + options.mValidN = options.mN; + } + if (options.mValidK == -1) { + options.mValidK = options.mK; + } + + // It must not exceed the padded dimensions. 
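+  // When updateOptions is set, out-of-range valid dimensions are clamped back to the padded
+  // M/N/K; otherwise the configuration is rejected.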
+ if (options.mValidM > options.mM || options.mValidN > options.mN || + options.mValidK > options.mK) { + TLLM_LOG_WARNING( + options.mValidK <= options.mK, + "ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively."); + if (updateOptions) { + options.mValidM = std::min(options.mValidM, options.mM); + options.mValidN = std::min(options.mValidN, options.mN); + options.mValidK = std::min(options.mValidK, options.mK); + } else { + return false; + } + } + + // BlockMajorK layout does not support validM, validN, validK parameters + if (options.mLayoutA == gemm::MatrixLayout::BlockMajorK || + options.mLayoutB == gemm::MatrixLayout::BlockMajorK) { + bool hasValidParams = (options.mValidM != -1 && options.mValidM != options.mM) || + (options.mValidN != -1 && options.mValidN != options.mN) || + (options.mValidK != -1 && options.mValidK != options.mK); + TLLM_CHECK_ERROR(!hasValidParams, + "BlockMajorK layout does not support validM/validN/validK parameters due to " + "swizzled layout. " + "Found validM=", + options.mValidM, " validN=", options.mValidN, " validK=", options.mValidK); + } + +#ifdef TLLM_PUBLIC_RELEASE + if (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3) { + TLLM_CHECK_ERROR(false, "E2m1 x E4m3 is not supported for JIT compile. Use cubins instead."); + } +#endif // TLLM_PUBLIC_RELEASE + // Check that the A cast is supported. // Currently, we only support {MxFp4, NvFp4} -> Bf16. TLLM_CHECK_ERROR( @@ -607,6 +712,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mDtypeA == tg::Dtype::MxE2m1 && options.mDtypeMmaA == tg::Dtype::Bfloat16, "PatchF2fp is only supported for MxFp4 to Bf16 casts."); } +#ifdef TLLM_PUBLIC_RELEASE + options.mPatchF2fp = false; +#endif // TLLM_PUBLIC_RELEASE // FIXME: We do not support different dtypes for A and B when not on Blackwell. if (!isBlackwell) { @@ -819,7 +927,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in (padMultiplierB * tg::dtypeGetNumBits(options.mDtypeB) * options.mK / 8) % 16 == 0, "K dimension of B must be aligned to 16 bytes."); - if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { + if (tg::dtypeIsBlockFmt(options.mDtypeC)) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); TLLM_CHECK_ERROR( @@ -836,6 +944,10 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); TLLM_CHECK_ERROR(hiddenDim % hiddenGranularity == 0, "Hidden dim (", hiddenDim, ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs."); + int const validHiddenDim = options.mTransposeMmaOutput ? 
options.mValidM : options.mValidN; + TLLM_CHECK_ERROR(validHiddenDim % tg::dtypeNumEltsPerSf(options.mDtypeC) == 0, + "Valid hidden dim (", validHiddenDim, ") must be a multiple of ", + tg::dtypeNumEltsPerSf(options.mDtypeC), " for block-scaled outputs."); TLLM_CHECK_ERROR(!options.mTransposeMmaOutput || options.mUseShuffledMatrixA, "Transposing block-scaled outputs requires shuffled A."); } @@ -901,8 +1013,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (options.mUseShuffledMatrixA) { auto const shuffleBlockSize = getShuffleBlockSize(options.mEpilogueTileM); - TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0, - "M must be a multiple of shuffle block size (", shuffleBlockSize, + TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0 && options.mValidM % shuffleBlockSize == 0, + "M/validM must be a multiple of shuffle block size (", shuffleBlockSize, ") when useShuffledMatrixA"); } @@ -1084,9 +1196,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // options.mUseTwoMmaWarps = true; // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mK % 128 == 0, - "GEMM-K must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mK); + TLLM_CHECK_ERROR(options.mK % 128 == 0 && options.mValidK % 128 == 0, + "GEMM-K and validK must be a multiple of 128 when using DeepSeek Fp8. Found ", + options.mK, " and validK=", options.mValidK); // Check that the output tile N can be processed with the epilogue tile granularity. TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerEpilogueTile == 0, @@ -1100,6 +1212,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in hiddenDimPerMma, ")"); } + TLLM_CHECK_ERROR(options.mNumEpilogueWarps == 4 || options.mNumEpilogueWarps == 8, + "mNumEpilogueWarps has to be either 4 or 8."); + if (options.mSliceK) { TLLM_CHECK_ERROR(isBlackwell, "Slice-K is not supported on Hopper"); @@ -1253,7 +1368,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "At least one matrix must be in k-major layout"); // Some features are currently only support when both matrices are in K-major format - if (options.mLayoutB != MatrixLayout::MajorK || options.mLayoutB != MatrixLayout::MajorK) { + if (options.mLayoutA != MatrixLayout::MajorK || options.mLayoutB != MatrixLayout::MajorK) { TLLM_CHECK_ERROR(isBlackwell, "Non K-major layouts are only supported on Blackwell"); TLLM_CHECK_ERROR(options.mSplitK == SplitK::None, "Non K-major layouts do not support split K"); } @@ -1303,6 +1418,31 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "Bias is not supported for Meta Fp8"); } + if (options.mUseMaxTmemOverlap) { + TLLM_CHECK_ERROR(options.mUseTmaStore, "mUseMaxTmemOverlap only works with TMA store"); + TLLM_CHECK_ERROR(options.mFuseUtccpWithUtcmma, + "mUseMaxTmemOverlap only works with mFuseUtccpWithUtcmma"); + TLLM_CHECK_ERROR(options.mNumSlicesForSplitK == 1, + "mUseMaxTmemOverlap does not work with splitK"); + TLLM_CHECK_ERROR(options.mNumSlicesForSliceK == 1, + "mUseMaxTmemOverlap does not work with sliceK"); + TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, + "mUseMaxTmemOverlap does not work with mUseDeepSeekFp8"); + TLLM_CHECK_ERROR(!options.mUseUnrollLoop2xForMma, + "mUseMaxTmemOverlap does not work with mUseUnrollLoop2xForMma"); + } + + if (options.mNumEpilogueWarps > 4) { + TLLM_CHECK_ERROR(options.mUseTmaStore, + "Using more than 4 warps for epilogue 
only works with TMA store"); + TLLM_CHECK_ERROR(options.mNumSlicesForSplitK == 1, + "Using more than 4 warps for epilogue does not work with splitK"); + TLLM_CHECK_ERROR(options.mNumSlicesForSliceK == 1, + "Using more than 4 warps for epilogue does not work with sliceK"); + TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, + "Using more than 4 warps for epilogue does not work with mUseDeepSeekFp8"); + } + if (updateOptions) { // Init kernel traits. options.mKernelTraits = KernelTraits( @@ -1311,6 +1451,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mTileK, options.mEpilogueTileM, options.mEpilogueTileN, options.mNumStages, options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, options.mSplitK, options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, + options.mFuseUtccpWithUtcmma, options.mUseMaxTmemOverlap, options.mNumEpilogueWarps, options.mTileScheduler == TileScheduler::Persistent, options.mUseDeepSeekFp8, options.mUsePerTokenSfA, options.mUsePerTokenSfB, /* useTwoCtas*/ options.mClusterDimX == 2, options.mBiasType); @@ -1321,6 +1462,59 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in //////////////////////////////////////////////////////////////////////////////////////////////////// +inline bool getDoesScaleC(tg::Dtype dtypeC) { + // Need to scale/quantize the output C matrix when the output type is Fp8 or NvFp4. + return dtypeC == tg::Dtype::E4m3 || dtypeC == tg::Dtype::E2m1; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool getDoesScaleAb(tg::Dtype dtypeA, tg::Dtype dtypeB, bool useDeepSeekFp8) { + // Need to scale/dequantize the input A/B matrices when the input type is Fp8 or NvFp4 and + // DeepSeekFp8 is not used. + bool const doesScaleAb{ + dtypeA == tg::Dtype::E2m1 || dtypeB == tg::Dtype::E2m1 || + ((dtypeA == tg::Dtype::E4m3 || dtypeB == tg::Dtype::E4m3) && !useDeepSeekFp8)}; + return doesScaleAb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool getKernelDoesScaleC(tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, + bool useDeepSeekFp8) { + // In the Gemm/BatchedGemm kernels, dequantScaleAb and quantScaleC are combined into one single + // scaling factor (called scaleC). As a result, we combine the logic for getDoesScaleAb and + // getDoesScaleC. + return getDoesScaleC(dtypeC) || getDoesScaleAb(dtypeA, dtypeB, useDeepSeekFp8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline CUresult loadCubinData(CUmodule* module, Config const& config) { + // Trtllm links the cubin into the executable while Flashinfer loads the cubin from storage. +#ifdef TLLM_GEN_EXPORT_FLASHINFER +#ifdef TLLM_GEN_GEMM_CUBIN_PATH + static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH); + const std::string sha256 = config.mHash ? 
config.mHash : ""; + std::string fileName = config.mFunctionName; + if (!fileName.empty()) { + fileName[0] = static_cast(std::toupper(static_cast(fileName[0]))); + } + const std::string& data = flashinfer::trtllm_cubin_loader::getCubin( + tllm_gen_gemm_cubin_path + "/" + fileName + ".cubin", sha256); + CUresult result = cuModuleLoadData(module, data.c_str()); +#else + static_assert(false, "TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling"); +#endif // TLLM_GEN_GEMM_CUBIN_PATH +#else + CUresult result = cuModuleLoadData(module, config.mData); +#endif // TLLM_GEN_EXPORT_FLASHINFER + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gemm #ifdef TLLM_GEN_EXPORT_INTERFACE diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h index 7e0474bb5f..800c8546ef 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h @@ -82,8 +82,18 @@ bool useTmaOobOptC(BatchedGemmOptions const& options) { // Create the TMA shape/stride for A/B/C. template -static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, int mK, int tileM, - int tileN, int tileK, MatrixType matrixType) { +static auto makeTmaShapeStrideAbc(GemmOptions const& options, int sizeM, int sizeN, int sizeK, + int tileM, int tileN, int tileK, MatrixType matrixType, + int validM = -1, int validN = -1, int validK = -1) { + if (validM == -1) { + validM = sizeM; + } + if (validN == -1) { + validN = sizeN; + } + if (validK == -1) { + validK = sizeK; + } // Weights matrix is A if we transpose the output of MMA (to have it M-major). // Otherwise, it is B, when the output of MMA is K-major. bool const isWeights = (matrixType == MatrixType::MatrixA && options.mTransposeMmaOutput) || @@ -96,9 +106,11 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in : matrixType == MatrixType::MatrixC ? useTmaOobOptC(options) : false; - // The outer dimension. + // The outer dimension. Uses padded dimensions for strides and valid dimensions for shapes. auto numTokens = - (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? mM : mN; + (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? sizeM : sizeN; + auto numTokensValid = + (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? validM : validN; // The outer dimension tile size. auto ctaTileNumTokens = (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? tileM : tileN; @@ -107,7 +119,8 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileM : ctaTileNumTokens; // The inner dimension. - auto hiddenSize = (matrixType == MatrixType::MatrixC) ? mN : mK; + auto hiddenSize = (matrixType == MatrixType::MatrixC) ? sizeN : sizeK; + auto hiddenSizeValid = (matrixType == MatrixType::MatrixC) ? validN : validK; // The inner dimension tile size. auto ctaTileHiddenSize = (matrixType == MatrixType::MatrixC) ? tileN : tileK; // The inner dimension of TMA box shape. @@ -117,6 +130,7 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // Swap matrix C sizes if output is transposed. 
if (matrixType == MatrixType::MatrixC && options.mTransposeMmaOutput) { std::swap(numTokens, hiddenSize); + std::swap(numTokensValid, hiddenSizeValid); std::swap(ctaTileNumTokens, ctaTileHiddenSize); std::swap(tileNumTokens, tileHiddenSize); } @@ -125,6 +139,7 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // gated activations but not regular activations. if (options.mFusedAct && matrixType == MatrixType::MatrixC) { hiddenSize /= 2; + hiddenSizeValid /= 2; tileHiddenSize /= 2; ctaTileHiddenSize /= 2; } @@ -134,17 +149,18 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // 1, so swap the first two dimension so that the hiddenSize dimension comes first. // Activations matrix is 2D (sum(divUpMul(M[bi], tileM) for bi in B), K). - std::vector shape = {static_cast(hiddenSize), - static_cast(numTokens)}; + // Use valid dimensions for shape. + std::vector shape = {static_cast(hiddenSizeValid), + static_cast(numTokensValid)}; if (useTmaOobOpt /* also implies input/output activation */) { // If TMA OOB optimization is used: // Shape [hidden, tokens] Stride [1, hidden] becomes // Shape [hidden, tileN, TmaDimMax, TmaDimMax] Stride [1, hidden, XLargeN - hidden, hidden] - shape = {static_cast(hiddenSize), static_cast(ctaTileNumTokens), + shape = {static_cast(hiddenSizeValid), static_cast(ctaTileNumTokens), static_cast(tg::TmaDimMax), static_cast(tg::TmaDimMax)}; } else if (isWeights) { // If the matrix is a weights matrix, we use 3D logical shape (B, M, K) or (B, N, K). - shape = {static_cast(hiddenSize), static_cast(numTokens), + shape = {static_cast(hiddenSizeValid), static_cast(numTokensValid), static_cast(options.mNumBatches)}; } @@ -177,10 +193,11 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in stride[1] = numTokens; std::swap(tileShape[0], tileShape[1]); } else if (layout == gemm::MatrixLayout::BlockMajorK) { - // Set shapes based on blocking layout + // Set shapes based on blocking layout. shape = {static_cast(options.mBlockK), static_cast(numTokens), - static_cast(mK / options.mBlockK), + static_cast(sizeK / options.mBlockK), static_cast(options.mNumBatches)}; + // Strides use padded dimensions stride = {1, static_cast(options.mBlockK), static_cast(numTokens * options.mBlockK), static_cast(hiddenSize * numTokens)}; @@ -209,17 +226,6 @@ static auto makeTmaShapeStrideSfAb(int mM, int mN, int mK, MatrixType matrixType switch (layout) { case tg::SfLayout::R128c4: { - // The scaling factor tensor packs 128x4 tiles into contiguous 512B blocks. - // The 512B block maps to a 32x16B (32x128b) block in TMEM. - // See https://nvbugspro.nvidia.com/bug/4165523 - // - // Additionally, we have to meet constraints of TMA that the box dimensions are less - // than 256 and boxDim[0] is a multiple of 16B. 
- // - // The "logical" tensor is: [outer, inner / numEltsPerSf] - // The aforementioned format is: [outer / 128, inner / numEltsPerSf / 4, 512] - // The shape we use for TMA is: [outer / 128, inner / numEltsPerSf / 4, 2, 256] - auto shape = std::vector{ 256, 2, static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4)), static_cast(ceilDiv(numTokens, 128))}; @@ -294,7 +300,6 @@ static auto makeTmaShapeStrideSfAb(int mM, int mN, int mK, MatrixType matrixType } return std::make_tuple(std::vector{}, std::vector{}, std::vector{}); } - template static KernelParams setKernelParams( GemmOptions_ const& options, bool const batchM, void const* ptrA, void const* ptrB, void* ptrC, @@ -390,9 +395,9 @@ static KernelParams setKernelParams( params.tileStridePerBatch = options.mM / options.mTileM; params.nm = options.mM; // Shape/stride for gmem tensor A. - auto [shapeA, strideA, tileShapeA] = - makeTmaShapeStrideAbc(options, options.mM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixA); + auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAbc( + options, options.mM, options.mN, options.mK, options.mTileM, options.mTileN, options.mTileK, + MatrixType::MatrixA, options.mValidM, options.mValidN, options.mValidK); // Build tma descriptor for A. params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, strideA, tileShapeA, const_cast(ptrA)); @@ -469,15 +474,17 @@ static KernelParams setKernelParams( // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. - auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( - options, options.mM, ctaOffset * options.mTileN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixC); + auto [shapeC, strideC, tileShapeC] = + makeTmaShapeStrideAbc(options, options.mM, ctaOffset * options.mTileN, options.mK, + options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixC, + options.mValidM, ctaOffset * options.mTileN, options.mValidK); // Build tma descriptor for C. params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, strideC, tileShapeC, ptrC); } else { params.ptrC = ptrC; } + } else { // B is the expert if (0 != options.mN % options.mTileN) { @@ -486,9 +493,9 @@ static KernelParams setKernelParams( params.tileStridePerBatch = options.mN / options.mTileN; params.nm = options.mN; // Shape/stride for gmem tensor B. - auto [shapeB, strideB, tileShapeB] = - makeTmaShapeStrideAbc(options, options.mM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixB); + auto [shapeB, strideB, tileShapeB] = makeTmaShapeStrideAbc( + options, options.mM, options.mN, options.mK, options.mTileM, options.mTileN, options.mTileK, + MatrixType::MatrixB, options.mValidM, options.mValidN, options.mValidK); // Build tma descriptor for B. params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, strideB, tileShapeB, const_cast(ptrB)); @@ -544,9 +551,10 @@ static KernelParams setKernelParams( // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. 
- auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( - options, ctaOffset * options.mTileM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixC); + auto [shapeC, strideC, tileShapeC] = + makeTmaShapeStrideAbc(options, ctaOffset * options.mTileM, options.mN, options.mK, + options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixC, + ctaOffset * options.mTileM, options.mValidN, options.mValidK); // Build tma descriptor for C. params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, strideC, tileShapeC, ptrC); diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h index 16b4af3149..e11374739f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h @@ -29,54 +29,6 @@ struct KernelParams { // Maximum number of CTAs in the batch-token dimension. static constexpr int MaxNumCtas = 2048; - // NOTE: TMA out-of-bounds optimization for MoE padded tokens: - // - // Originally the padded tokens is a 2D tensor [hiddenDim, ctaGridDimY * tileN] with stride [1, - // hiddenDim] and box size [tileM, tileN] at pointer p. We waste bandwidth bytes since we only - // want to load [0, batchEnd) out of the [0, tileN) box size: batchEnd is a runtime variable while - // box size needs to be fixed at compile time. - // - // To deal with this, we reshape the tensor to 3D: [hiddenDim, tileN, ctaGridDimY * tileN] with - // stride [1, hiddenDim, hiddenDim] and box size [tileM, tileN, 1]. For the original 2D - // tensor, - // - // Offset Coords [ : , ctaIdxY * tileN ], - // Box Sizes [ : , tileN ], - // Coords Range [ : , ctaIdxY * tileN : ctaIdxY * tileN + tileN], - // - // while we only want load the range [ctaIdxY * tileN, ctaIdxY * tileN + batchEnd), 1 <= batchEnd - // <= tileN - // - // For the reshaped 3D tensor, - // - // Offset Coords [ : , tileN - batchEnd , - // ctaIdxY * tileN + batchEnd ], - // Box Sizes [ : , tileN , - // 1 ], - // Coords Range [ : , tileN - batchEnd : min(tileN, 2 * tileN - batchEnd), - // ctaIdxY * tileN + batchEnd : ctaIdx * tileN + batchEnd + 1], - // - // while min(tileN, 2 * tileN - batchEnd) always evaluates to tileN. The unwanted tokens are - // essentially filtered out by utilizing the OOB feature of TMA. Since the 2nd and 3rd dimension - // has the same stride, we end up loading the following (adding the left and right end of the 2nd - // and 3rd dimension ranges): - // - // Effective 2D Coords Range - // [ : , tileN + ctaIdxY * tileN : tileN + ctaIdxY * tileN + batchEnd], - // - // This is exactly the same as the original range except for the offset tileN, thus we also need - // to offset the pointer in the opposite direction: - // - // Ptr (p) -> Ptr (p - tileN * hiddenDim) - // - // Due to the restrictions of TMA unit, the above operations requires the TMA descriptor and the - // underlying buffer be constructed differently: - // - Requires valid buffer at (p - tileN * hidden) - needs prepending `tileN` tokens. - // - TMA outermost dimension must be extended by `tileN` or loads will OOB in the rightmost side. - // The latter is because when batchEnd == tileN, the offset coords in the 3rd dimension becomes - // ctaIdxY * tileN + tileN. 
When ctaIdxY = ctaGridDimY - 1, it becomes ((ctaGridDimY - 1) * tileN - // + tileN = ctaGridDimY * tileN which is equal to the 3rd dimension size and will be filtered - // out. That's why we need to extend the tensor size by tileN. // // TMA descriptor for A. // Must be setup using gemm::buildNdTmaDescriptor with shapes and strides from diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h index 4d79f83c23..4ea0a91250 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include "Enums.h" @@ -162,9 +163,12 @@ class KernelTraits { int32_t epilogueTileN, int32_t numStages, int32_t numStagesMma, int32_t numSlicesForSplitK, int32_t numSlicesForSliceK, SplitK splitK, bool useTmaStore, bool transposeMmaOutput, AllReduceAlgo allReduceAlgo, + bool fuseUtccpWithUtcmma, bool useMaxTmemOverlap, int32_t numEpilogueWarps, bool usePersistentScheduler, bool useDeepSeekFp8, bool usePerTokenSfA, bool usePerTokenSfB, bool useTwoCtas, BiasType biasType) - : mMmaKind{mmaKind} { + : mMmaKind{mmaKind}, + mFuseUtccpWithUtcmma{fuseUtccpWithUtcmma}, + mUseMaxTmemOverlap{useMaxTmemOverlap} { // // SMEM // @@ -271,6 +275,10 @@ class KernelTraits { extraGmemCMultiplier = 0; } + if (numEpilogueWarps) { + extraGmemCMultiplier *= numEpilogueWarps / 4; + } + // Number of bytes to store the output in smem. auto const numBytesSmemStoreC = usesSmemForGmemC ? extraGmemCMultiplier * epilogueTileM * epilogueTileN * @@ -418,8 +426,11 @@ class KernelTraits { std::vector tmemChunkNames; // Matrix D { + // Two set of TMEM resources for D share epilogueTileN columns, + // | set0:epiTileN0 | set0:epiTileN1/set1:epiTileN0 | set1:epiTileN1 | + auto const numCols = mUseMaxTmemOverlap ? 2 * tileN - epilogueTileN : tileN; // Number of columns for accumulators. - auto const numTmemColsD = numSlicesForSliceK * tileN * numStagesMma * + auto const numTmemColsD = numSlicesForSliceK * numCols * numStagesMma * tg::dtypeGetNumBits(dtypeAcc) / tg::dtypeGetNumBits(tg::Dtype::UInt32); // Number of columns for D alignment. @@ -466,9 +477,9 @@ class KernelTraits { auto const numTmemColsSfA = useConstSfA ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK), 4) - : (useBlockScalingA - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * numStages - : 0); + : (useBlockScalingA ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * + (mFuseUtccpWithUtcmma ? 1 : numStages) + : 0); // Number of columns for Sf alignment. auto const numColsAlignmentSfA = 4; // No need to reuse TMEM. @@ -491,9 +502,9 @@ class KernelTraits { auto const numTmemColsSfB = useConstSfB ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK), 4) - : (useBlockScalingB - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * numStages - : 0); + : (useBlockScalingB ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * + (mFuseUtccpWithUtcmma ? 1 : numStages) + : 0); // Number of columns for Sf alignment. auto const numColsAlignmentSfB = 4; // No need to reuse TMEM. @@ -515,6 +526,10 @@ class KernelTraits { public: // The MMA kind. tg::MmaKind mMmaKind; + // Whether fuse Utccp into the MMA task. + bool mFuseUtccpWithUtcmma; + // Whether use the max TMEM overlap trick. + bool mUseMaxTmemOverlap; // Helper for SMEM allocation. 
MemAllocatorHelper mSmemAllocatorHelper; // Helper for TMEM allocation. diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h index a1412444ae..c7b18af138 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h @@ -156,7 +156,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + ss << "Error: Failed to initialize the TMA descriptor. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; @@ -251,7 +251,7 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor for SF " << errorString << std::endl; + ss << "Error: Failed to initialize the TMA descriptor for SF. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h index c7f1020dea..53155c8ffb 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h @@ -38,8 +38,6 @@ constexpr unsigned long XLargeN = 1UL << 35; //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h index 965bb1b7b8..56b537ff42 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h @@ -63,8 +63,6 @@ enum class SfLayout { // | 1,0 | 1,1 | 1,2 | 1,3 | 33,0 | 33,1 | 33,2 | 33,3 | ... | 97,3 | // | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | // | 31,0 | 31,1 | 31,2 | 31,3 | 63,0 | 63,1 | 63,2 | 63,3 | ... | 127,3 | - // See https://nvbugspro.nvidia.com/bug/4165523 - // // I.e., the SF buffer is a tensor [โŒˆm/128โŒ‰, โŒˆn/b/4โŒ‰, 32, 4, 4] // The SF for the element (i, j) is stored at (i/128, j/b/4, i%32, (i%128)/32, (j/b)%4). 
R128c4, diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 50d3baecc7..e3a0d21884 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -113,58 +113,67 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported pair"); \ } +#define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mPaddingLog2 > 0) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, true), kernel, numBlocks, numThreads, smemSize, \ + stream); \ + } else { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, false), kernel, numBlocks, numThreads, \ + smemSize, stream); \ + } + #define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), kernel, \ - numBlocks, numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \ - numBlocks, numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else { \ FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, extraFlag, numExperts) \ - if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && 
data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ +#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, extraFlag, numExperts) \ + if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeExpW"); \ } #define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ @@ -182,17 +191,17 @@ namespace moe::dev { #define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ stream, extraFlag1, numExperts) \ if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ - 
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ + numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \ + numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ - kernel, numBlocks, numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } else { \ FLASHINFER_WARN("Unsupported dtypeExpW"); \ } diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh index dd7d5c474d..d110037269 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh @@ -67,6 +67,24 @@ __host__ __device__ constexpr T divUpMulLog2(T a, T bLog2) { return mulLog2(divUpLog2(a, bLog2), bLog2); } +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__host__ __device__ constexpr T mulTileN(T a, T tileN) { + return a * tileN; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__host__ __device__ constexpr T divUpTileN(T a, T tileN) { + return (a + tileN - 1) / tileN; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__host__ __device__ constexpr T divUpMulTileN(T a, T tileN) { + return divUpTileN(a, tileN) * tileN; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// __host__ __device__ constexpr int32_t getBits(int32_t value, int idx) { @@ -299,7 +317,14 @@ __device__ void routingPermutation(KernelParams params, // Compute the runtime config for projections // Whether or not an expert is local is taken into account when smemExpertCount is computed // so we do not need to take it into account here. 
- const int32_t numCta = divUpLog2(count, params.mPaddingLog2); + + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } + int32_t ctaOffset; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); @@ -310,21 +335,37 @@ __device__ void routingPermutation(KernelParams params, const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = - min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffset, params.mPaddingLog2) + count); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } // get the padded offset associated with this expert - const int32_t offset = mulLog2(ctaOffset, params.mPaddingLog2); - + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } // write expert offsets to shared smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; } // write out padded count if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) { - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -513,14 +554,25 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // Compute the runtime config for projections // Whether or not an expert is local is taken into account when the histogram is computed // so we do not need to take it into account here. 
- const int32_t numCta = divUpLog2(count, params.mPaddingLog2); + // const int32_t numCta = divUpLog2(count, params.mPaddingLog2); + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } int32_t ctaOffset; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); if (threadIdx.x < params.mNumExperts) { // Get the padded offset associated with this expert - const int32_t offset = mulLog2(ctaOffset, params.mPaddingLog2); + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } // Write expert offsets to shared smemExpertOffset[threadIdx.x] = offset; @@ -532,7 +584,12 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // The first block writes out padded count if (blockIdx.x == 0 && warpIdx == KernelParams::MaxNumExperts / WarpSize - 1 && cute::elect_one_sync()) { - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -543,9 +600,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = - min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffset, params.mPaddingLog2) + count); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } } diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index e424d91db0..cae6729368 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -50,7 +50,7 @@ struct DataBase { // dim: [mNumTokens * mTopK] int32_t* mPtrExpandedIdxToPermutedIdx{nullptr}; // optional: if `nullptr`, it is not filled - // dim: [mNumTokens * mTopK + (mNumExperts << mPaddingLog2) - mNumExperts] + // dim: [mTileTokensDim * mTopK + (mNumExperts ร— mTileTokensDim) - mNumExperts] // Note: this array (mPtrPermutedIdxToTokenIdx) is uninitialized // Any out-of-bounds values are undefined. 
int32_t* mPtrPermutedIdxToTokenIdx{nullptr}; @@ -93,6 +93,7 @@ struct DataBase { int32_t mNumExperts; int32_t mTopK; int32_t mPaddingLog2; + int32_t mTileTokensDim; /// For expert parallelization int32_t mLocalExpertsStartIdx; @@ -100,11 +101,12 @@ struct DataBase { int32_t mNumLocalExperts; }; -template +template struct KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr int MaxNumExperts = MaxNumExperts_; + static constexpr bool isPow2 = isPow2_; static constexpr bool UsePdl = UsePdl_; // Public pointer members @@ -123,7 +125,8 @@ struct KernelParamsBase { int32_t mNumTokens = 0; int32_t mNumExperts = 0; - int32_t mPaddingLog2 = 0; + int32_t mPaddingLog2 = -1; + int32_t mTileTokensDim = 0; int32_t mLocalExpertsStartIdx = 0; int32_t mLocalExpertsStrideLog2 = 0; int32_t mNumLocalExperts = 0; @@ -146,6 +149,7 @@ struct KernelParamsBase { mNumExperts = data.mNumExperts; mPaddingLog2 = data.mPaddingLog2; + mTileTokensDim = data.mTileTokensDim; mLocalExpertsStartIdx = data.mLocalExpertsStartIdx; mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2; mNumLocalExperts = data.mNumLocalExperts; @@ -173,8 +177,8 @@ struct Data : public DataBase { }; template -struct KernelParams : public KernelParamsBase { + bool isPow2_, bool UsePdl_> +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using BiasT = BiasT_; using OutputT = OutputT_; @@ -229,8 +233,8 @@ struct Data : public DataBase { tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; @@ -268,8 +272,8 @@ struct Data : public DataBase { }; template -struct KernelParams : public KernelParamsBase { + bool isPow2_, bool UsePdl_> +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index df19e00310..65f497ad90 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -2084,80 +2084,57 @@ def run_moe_test( ) -# Test: DeepSeekV3 routing +# Test: Renormalize routing @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ + pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), - pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), ], ) @pytest.mark.parametrize( "routing_config", [ - pytest.param( - { - "num_experts": 384, - "top_k": 8, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - "compatible_intermediate_size": [512, 1024, 2048], - }, - id="kimi_k2", - ), pytest.param( { "num_experts": 256, "top_k": 8, "padding": 8, - "n_groups": 8, - "top_k_groups": 4, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - "compatible_intermediate_size": [512, 1024, 2048], + "n_groups": None, + "top_k_groups": None, + 
"routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_intermediate_size": [384, 768, 1024, 2048], }, - id="DSv3", + id="Renorm", ), pytest.param( { - "num_experts": 72, - "top_k": 6, + "num_experts": 512, + "top_k": 10, "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - "compatible_intermediate_size": [384, 768], + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_intermediate_size": [512], }, - id="DSLite", + id="Qwen3_next", ), ], ) @pytest.mark.parametrize( "weight_processing", [ - pytest.param( - { - "use_shuffled_weight": False, - "layout": WeightLayout.MajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - }, - id="NoShuffle_MajorK", - ), pytest.param( { "use_shuffled_weight": True, @@ -2166,14 +2143,6 @@ def run_moe_test( }, id="Shuffled_MajorK", ), - pytest.param( - { - "use_shuffled_weight": True, - "layout": WeightLayout.BlockMajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - }, - id="Shuffled_BlockMajorK", - ), ], ) @pytest.mark.parametrize( @@ -2183,7 +2152,7 @@ def run_moe_test( pytest.param(GatedActType.GeGlu, id="GeGlu"), ], ) -def test_deepseekv3_routing( +def test_renormalize_routing( num_tokens, hidden_size, intermediate_size, @@ -2193,7 +2162,7 @@ def test_deepseekv3_routing( gated_act_type, cache_permute_indices, ): - """Test DeepSeekV3 routing configurations.""" + """Test Renormalize routing configurations.""" run_moe_test( num_tokens, hidden_size, @@ -2206,58 +2175,80 @@ def test_deepseekv3_routing( ) -# Test: Renormalize routing +# Test: DeepSeekV3 routing @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ + pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), - pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), ], ) @pytest.mark.parametrize( "routing_config", [ + pytest.param( + { + "num_experts": 384, + "top_k": 8, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [512, 1024, 2048], + }, + id="kimi_k2", + ), pytest.param( { "num_experts": 256, "top_k": 8, "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], - "compatible_intermediate_size": [384, 768, 1024, 2048], + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [512, 1024, 2048], }, - id="Renorm", - 
marks=pytest.mark.skip(reason="Skip temporary"), + id="DSv3", ), pytest.param( { - "num_experts": 512, - "top_k": 10, + "num_experts": 72, + "top_k": 6, "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], - "compatible_intermediate_size": [512], + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [384, 768], }, - id="Qwen3_next", + id="DSLite", ), ], ) @pytest.mark.parametrize( "weight_processing", [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="NoShuffle_MajorK", + ), pytest.param( { "use_shuffled_weight": True, @@ -2266,6 +2257,14 @@ def test_deepseekv3_routing( }, id="Shuffled_MajorK", ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="Shuffled_BlockMajorK", + ), ], ) @pytest.mark.parametrize( @@ -2275,7 +2274,7 @@ def test_deepseekv3_routing( pytest.param(GatedActType.GeGlu, id="GeGlu"), ], ) -def test_renormalize_routing( +def test_deepseekv3_routing( num_tokens, hidden_size, intermediate_size, @@ -2285,7 +2284,7 @@ def test_renormalize_routing( gated_act_type, cache_permute_indices, ): - """Test Renormalize routing configurations.""" + """Test DeepSeekV3 routing configurations.""" run_moe_test( num_tokens, hidden_size, From 26d587a44aa4f9005c01529467c4a5965b709d26 Mon Sep 17 00:00:00 2001 From: FlashInfer Bot Date: Wed, 5 Nov 2025 23:29:17 -0800 Subject: [PATCH 028/130] chore: Update CODEOWNERS (#1984) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR updates the CODEOWNERS file based on git commit history analysis from the last 180 days. ## Changes - Updated `.github/CODEOWNERS` with current code ownership based on: - Commit frequency - File coverage - Commit recency ## How to Review 1. Review the changes to `.github/CODEOWNERS` 2. Verify that the assigned owners are appropriate for each module 3. Make manual adjustments if needed before merging ## Notes - This is an automated PR generated weekly - Minimum commits threshold: 1 - Analysis period: 180 days - Directory depth: 3 levels - Top N owners per module: 5 --- ๐Ÿค– This PR was automatically generated by the [update-codeowners workflow](.github/workflows/update-codeowners.yml) ## Summary by CodeRabbit * **Chores** * Updated code ownership assignments and reorganized related section mappings for internal development processes. 
Co-authored-by: flashinfer-bot Co-authored-by: Claude --- .github/CODEOWNERS | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0897d7f37c..2e26c661f6 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,21 +3,21 @@ # Analysis period: 180 days # Minimum commits threshold: 1 -benchmarks/ @bkryu @cyx-6 @nv-yunzheq @kahyunnam @jiahanc +benchmarks/ @bkryu @cyx-6 @jiahanc @nv-yunzheq @kahyunnam benchmarks/routines/ @bkryu @nv-yunzheq @cyx-6 @nvmbreughe @Anerudhan ci/ @cyx-6 @yzh119 @nvmbreughe ci/scripts/ @cyx-6 ci/scripts/jenkins/ @cyx-6 csrc/ @wenscarl @yzh119 @cyx-6 @djmmoss @yongwww -csrc/fused_moe/ @yzh119 @yongwww @djmmoss @wenscarl @cyx-6 -csrc/fused_moe/cutlass_backend/ @yzh119 @yongwww @djmmoss @wenscarl @cyx-6 -csrc/nv_internal/ @wenscarl @djmmoss @yzh119 @cyx-6 @yongwww +csrc/fused_moe/ @yzh119 @yongwww @djmmoss @cyx-6 @wenscarl +csrc/fused_moe/cutlass_backend/ @yzh119 @yongwww @djmmoss @cyx-6 @wenscarl +csrc/nv_internal/ @wenscarl @djmmoss @cyx-6 @yzh119 @yongwww csrc/nv_internal/cpp/ @wenscarl @yongwww @djmmoss @joker-eph @ttyio csrc/nv_internal/include/ @wenscarl -csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @yzh119 @cyx-6 @yongwww -csrc/xqa/ @yzh119 @cyx-6 +csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @cyx-6 @yzh119 @yongwww +csrc/xqa/ @cyx-6 @yzh119 docs/ @yzh119 @cyx-6 @wenscarl @nv-yunzheq @aleozlx -flashinfer/ @yzh119 @cyx-6 @wenscarl @nvmbreughe @bkryu +flashinfer/ @yzh119 @cyx-6 @wenscarl @nvmbreughe @yongwww flashinfer-cubin/ @yzh119 @cyx-6 flashinfer-cubin/flashinfer_cubin/ @yzh119 flashinfer-jit-cache/ @yzh119 @cyx-6 @@ -26,18 +26,18 @@ flashinfer/comm/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @djmmoss flashinfer/cudnn/ @Anerudhan @yzh119 @cyx-6 @Anerudhan flashinfer/cute_dsl/ @yzh119 @kaixih @Amir-19 @aleozlx flashinfer/fused_moe/ @djmmoss @yzh119 @cyx-6 @wenscarl @IwakuraRein -flashinfer/jit/ @yzh119 @cyx-6 @djmmoss @aleozlx @yongwww -flashinfer/jit/attention/ @yzh119 @Anerudhan @joker-eph +flashinfer/jit/ @yzh119 @cyx-6 @djmmoss @jiahanc @aleozlx +flashinfer/jit/attention/ @yzh119 @cyx-6 @Anerudhan @joker-eph flashinfer/jit/gemm/ @yzh119 flashinfer/logits_processor/ @cyx-6 @yzh119 flashinfer/profiler/ @cyx-6 flashinfer/triton/ @cyx-6 @nvmbreughe @yzh119 flashinfer/tuning_configs/ @kaixih -include/ @yzh119 @cyx-6 @wenscarl @kahyunnam @joker-eph -include/flashinfer/ @yzh119 @cyx-6 @wenscarl @kahyunnam @joker-eph +include/ @yzh119 @wenscarl @kahyunnam @joker-eph @cyx-6 +include/flashinfer/ @yzh119 @wenscarl @kahyunnam @joker-eph @cyx-6 include/flashinfer/attention/ @yzh119 @kahyunnam @joker-eph include/flashinfer/comm/ @yongwww @nvmbreughe @djmmoss @yzh119 @cyx-6 -include/flashinfer/gemm/ @ttyio @yongwww @aleozlx @cyx-6 +include/flashinfer/gemm/ @ttyio @yongwww @aleozlx include/flashinfer/trtllm/ @joker-eph @aleozlx @yzh119 @cyx-6 @wenscarl profiler/ @cyx-6 scripts/ @yzh119 @nvmbreughe @dierksen @yongwww @bkryu From aacc8dfe390b9e0c99f9ae8db4b2ff9506b06ccd Mon Sep 17 00:00:00 2001 From: ChristinaZ <83400082+ChristinaZ@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:55:39 +0800 Subject: [PATCH 029/130] Add support for topkPacked input in block-level renormalize (#2051) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Add support for topkPacked input in block-level renormalize ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! 
Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Performance** * Optimized routing layer efficiency through improved index handling in specialized processing configurations. Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> --- csrc/trtllm_fused_moe_routing_renormalize.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 56939f8d02..91b8fc5075 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -143,6 +143,14 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } } } // end if (validToken) + } else if (params.mPtrTopKPacked != nullptr) { + if (validToken) { + if (laneIdx < params.mTopK) { + int offset = + warpIdx * MaxNumExperts + params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].idx; + smemKIdx[offset] = static_cast(laneIdx); + } + } } __syncthreads(); From f25929f158c6bfb015704ab48f09b77d512e0329 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Thu, 6 Nov 2025 02:21:38 -0800 Subject: [PATCH 030/130] test: Skip test_fp8_quantize.py on Hopper (#2052) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description The unit test `test_fp8_quantize.py` currently fails on sm90. Root cause: The test file tests the accuracy of `mxfp8_quantize()`. However, in [fp8_quantization.py](https://github.com/flashinfer-ai/flashinfer/blob/adb0e89fdee0a3140a43982bc3bef4e79ce20046/flashinfer/fp8_quantization.py#L7), the `mxfp8_quantize()`'s underlying module only exists for `gen_mxfp8_quantization_sm100_module` with no sm90 support. Current PR changes test file to skip for pre-SM100 SM archs as they are not supported.. Results: * Before current PR on SM90: `72 failed, 40 passed in 2.69s` * After current PR on SM90: `40 passed, 72 skipped in 1.41s` * Before current PR on SM120: `112 passed in 1.59s` * After current PR on SM120: `112 passed in 1.54s` (expected to be the same as before) ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Added conditional checks to skip FP8 quantization tests on GPUs that lack required computational capabilities. 
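
For readers of this series: the capability guard added in the diff below could be factored into a small pytest helper if more mxfp8-only tests land later. This is a minimal sketch, assuming the `get_compute_capability` utility already used in the diff; the helper name `skip_if_no_mxfp8` is hypothetical and not part of this patch:

```python
import pytest
import torch

from flashinfer.utils import get_compute_capability


def skip_if_no_mxfp8(device: str = "cuda:0") -> None:
    # The mxfp8 quantization module in this series is only generated for SM100+,
    # so tests exercising it should be skipped on earlier GPUs (e.g. SM90 / Hopper).
    major, _ = get_compute_capability(torch.device(device))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")
```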
--- tests/utils/test_fp8_quantize.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/utils/test_fp8_quantize.py b/tests/utils/test_fp8_quantize.py index 50352eacc1..a9fe4c41c7 100644 --- a/tests/utils/test_fp8_quantize.py +++ b/tests/utils/test_fp8_quantize.py @@ -2,6 +2,7 @@ import torch from flashinfer import mxfp8_dequantize_host, mxfp8_quantize +from flashinfer.utils import get_compute_capability @pytest.mark.parametrize("m", [1, 1024]) @@ -10,6 +11,13 @@ @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_mxfp8_quantize_torch(m, k, dtype, is_sf_swizzled_layout, device): + if device == "cuda": + major, _ = get_compute_capability(torch.device(device)) + if major < 10: + pytest.skip( + "mxfp8 quantization is not supported on compute capability < 10" + ) + a = 16 * torch.randn([m, k], dtype=dtype).to(device).contiguous() if device == "cpu": @@ -90,6 +98,10 @@ def test_mxfp8_quantize_torch_host(m, k, dtype, is_sf_swizzled_layout): @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout): + major, _ = get_compute_capability(torch.device("cuda:0")) + if major < 10: + pytest.skip("mxfp8 quantization is not supported on compute capability < 10") + torch.random.manual_seed(0) a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous() @@ -114,6 +126,10 @@ def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout): def test_mxfp8_quantize_alignment_torch_device( m, k, dtype, is_sf_swizzled_layout, alignment ): + major, _ = get_compute_capability(torch.device("cuda:0")) + if major < 10: + pytest.skip("mxfp8 quantization is not supported on compute capability < 10") + torch.random.manual_seed(0) a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous() padded_k = ((k + alignment - 1) // alignment) * alignment From 55ea78719434916870d819f6278d5ed4fa977c19 Mon Sep 17 00:00:00 2001 From: Lain Date: Thu, 6 Nov 2025 13:33:34 -0800 Subject: [PATCH 031/130] [BUG] Fix trtllm-gen fp4 moe renormalize routing (#2049) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Temporarily disable `routingIndicesBlockKernel` as it's not compatible with the current packing format (topk-id and expert weights are packed into a 32 bit tensor). This solves the issue https://github.com/flashinfer-ai/flashinfer/issues/2032 ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Forced multi-block MoE execution to avoid sporadic single-block selection and improve stability with certain workloads. 
* **New Features** * Added an alternative packed topโ€‘k routing input path that propagates routing scores when present. * **Tests** * Added a comprehensive parametrized test validating routed fused MoE across token counts, model sizes, expert counts and multiple quantization modes. --------- Signed-off-by: Siyuan Fu Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> Co-authored-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> --- csrc/trtllm_fused_moe_routing_renormalize.cu | 12 +- tests/moe/test_trtllm_gen_routed_fused_moe.py | 244 ++++++++++++++++++ 2 files changed, 253 insertions(+), 3 deletions(-) create mode 100644 tests/moe/test_trtllm_gen_routed_fused_moe.py diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 91b8fc5075..d3a63431a8 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -146,9 +146,13 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } else if (params.mPtrTopKPacked != nullptr) { if (validToken) { if (laneIdx < params.mTopK) { - int offset = - warpIdx * MaxNumExperts + params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].idx; + int offset = warpIdx * MaxNumExperts + + static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].idx); smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = + static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].score); + } } } } @@ -430,7 +434,9 @@ void run(Data const& data, void* stream) { TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; - bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; + // FIXME: routingIndicesBlockKernel breaks the vllm + gpt-oss DeepEP + // bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; + bool const useSingleBlock = false; bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py new file mode 100644 index 0000000000..8bda03e971 --- /dev/null +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -0,0 +1,244 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest +from typing import Literal +import torch + +from flashinfer import ( + RoutingMethodType, + GatedActType, + fp4_quantize, + mxfp8_quantize, +) +from flashinfer.fused_moe import ( + trtllm_fp4_block_scale_moe, + trtllm_fp4_block_scale_routed_moe, +) +from flashinfer.utils import device_support_pdl + +from .test_trtllm_gen_fused_moe import ( + routing_reference_renormalize, + routing_reference_renormalize_naive, + routing_reference_topk, +) + + +@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("hidden_size", [1024, 2048, 3072, 4096]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048, 3072, 4096]) +@pytest.mark.parametrize("num_experts", [128, 256]) +@pytest.mark.parametrize("top_k", [4, 8]) +@pytest.mark.parametrize( + "routing_method_type", + [ + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + RoutingMethodType.TopK, + ], +) +@pytest.mark.parametrize("quant_mode", ["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"]) +def test_trtllm_gen_routed_fused_moe( + num_tokens: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + num_experts: int, + routing_method_type: RoutingMethodType, + quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], +): + torch.manual_seed(42) + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + routing_logits = torch.rand(num_tokens, num_experts, device=device).to( + torch.bfloat16 + ) + hidden_states = ( + torch.randn(num_tokens, hidden_size, device=device).to(torch.bfloat16) * 0.1 + ) + if quant_mode == "NvFP4xNvFP4": + hidden_states, hidden_states_scale = fp4_quantize( + hidden_states, + torch.tensor([448.0 * 6.0], device=device), + sf_vec_size=16, + sf_use_ue8m0=False, + is_sf_swizzled_layout=False, + ) + hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( + num_tokens, -1 + ) + hidden_states_global_scale = 1.0 / 448.0 / 6.0 + elif quant_mode == "MxFP4xMxFP8": + hidden_states, hidden_states_scale = mxfp8_quantize(hidden_states, False) + hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( + num_tokens, -1 + ) + hidden_states_global_scale = 1.0 + else: # MxFP4xBf16 + hidden_states_scale = None + hidden_states_global_scale = 1.0 + + w13 = ( + torch.randn(num_experts, intermediate_size * 2, hidden_size, device=device).to( + torch.bfloat16 + ) + * 0.1 + ) + w2 = ( + torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( + torch.bfloat16 + ) + * 0.1 + ) + if quant_mode == "NvFP4xNvFP4": + w13, w13_scale = fp4_quantize( + w13, + torch.tensor([448.0 * 6.0], device=device), + sf_vec_size=16, + sf_use_ue8m0=False, + ) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, -1 + ) + w2, w2_scale = fp4_quantize( + w2, + torch.tensor([448.0 * 6.0], device=device), + sf_vec_size=16, + sf_use_ue8m0=False, + ) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, -1 + ) + w13_global_scale = 1.0 / 448.0 / 6.0 + w2_global_scale = 1.0 / 448.0 / 6.0 + else: + w13, w13_scale = fp4_quantize( + w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True + ) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, -1 + ) + w2, w2_scale = fp4_quantize( + w2, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True + ) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, -1 + ) + w13_global_scale = 1.0 + w2_global_scale = 1.0 + + 
output1_scale_scalar = torch.tensor( + [hidden_states_global_scale * w13_global_scale] * num_experts, device=device + ) + output1_scale_gate_scalar = torch.tensor( + [hidden_states_global_scale * w13_global_scale] * num_experts, device=device + ) + output2_scale_scalar = torch.tensor( + [hidden_states_global_scale * w2_global_scale] * num_experts, device=device + ) + + reference_output = trtllm_fp4_block_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + hidden_states_scale, + w13, + w13_scale, + None, # w13_bias + None, # gemm1_alpha + None, # gemm1_beta + None, # gemm1_clamp_limit + w2, + w2_scale, + None, # w2_bias + output1_scale_scalar, + output1_scale_gate_scalar, + output2_scale_scalar, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + None, # routed_scaling_factor + None, # tile_tokens_dim + routing_method_type.value, + True, # do_finalize + enable_pdl, + GatedActType.SwiGlu.value, # gated_act_type + None, + )[0].to(torch.float) + + if routing_method_type == RoutingMethodType.Renormalize: + permute_info, expert_weights = routing_reference_renormalize( + routing_logits, top_k, num_experts, 8 + ) + elif routing_method_type == RoutingMethodType.RenormalizeNaive: + permute_info, expert_weights = routing_reference_renormalize_naive( + routing_logits, top_k, num_experts, 8 + ) + elif routing_method_type == RoutingMethodType.TopK: + permute_info, expert_weights = routing_reference_topk( + routing_logits, top_k, num_experts, 8 + ) + topk_ids = permute_info["topKIndices"].to(torch.int32) + expert_weights = expert_weights.view(num_tokens, num_experts)[ + torch.arange(num_tokens).unsqueeze(1), topk_ids + ].to(torch.bfloat16) + + packed_tensor = (topk_ids.to(torch.int32) << 16) | expert_weights.to( + torch.bfloat16 + ).view(torch.int16) + + output = trtllm_fp4_block_scale_routed_moe( + packed_tensor, + None, # routing_bias + hidden_states, + hidden_states_scale, + w13, + w13_scale, + None, # w13_bias + None, # gemm1_alpha + None, # gemm1_beta + None, # gemm1_clamp_limit + w2, + w2_scale, + None, # w2_bias + output1_scale_scalar, + output1_scale_gate_scalar, + output2_scale_scalar, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + None, # routed_scaling_factor + None, # tile_tokens_dim + routing_method_type.value, + True, # do_finalize + enable_pdl, + GatedActType.SwiGlu.value, # gated_act_type + None, + )[0].to(torch.float) + + mask = torch.isclose(output, reference_output, rtol=1e-3, atol=1e-3) + + # mismatch percentage + mismatch_pct = (~mask).float().mean().item() * 100 + assert mismatch_pct < 6, f"Mismatch percentage is {mismatch_pct:.2f}" From 63cf56227e8200e4a6a70dffcefec6542b91a756 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Thu, 6 Nov 2025 16:27:41 -0800 Subject: [PATCH 032/130] release: Bump version for v0.5.2 release (#2057) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. 
> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Version updated to 0.5.2 --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 4b9fcbec10..cb0c939a93 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.5.1 +0.5.2 From adcc5dd41037bbd77a68800b15f5e0235c2975ac Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 6 Nov 2025 16:58:30 -0800 Subject: [PATCH 033/130] perf: improve sampling/mask/softmax performance (part 1/2) (#2044) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This is the first part of the performance improvement PR for sampling/mask/softmax operator, in this PR, we defer the cross thread reduction till the end of the loop (similar to how FA2 handles denominator) to reduce the number of shuffling and thread sync instructions. For the second part of the PR, we will implement the Radix TopK algorithm to improve top-k mask logits performance when K is small. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Added comprehensive benchmarking suite for sampling and softmax operations with performance comparison and visualization tools. * **Chores** * Optimized internal kernel execution strategies for improved performance efficiency. 
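The kernel diff below is easier to follow with the merge rule spelled out. Here is a minimal NumPy sketch of the idea, not the CUDA code itself: every lane keeps a private, rescaled partial denominator across chunks, only the per-chunk max is shared, and one cross-lane reduction at the end recovers the exact softmax denominator. Lane and chunk sizes are arbitrary illustration values.

```python
import numpy as np


def deferred_softmax_denominator(x, num_lanes=4, chunk=8):
    # Illustrative NumPy model of the deferred reduction, not the CUDA kernel.
    running_max = -np.inf                    # kept identical on every lane
    lane_denom = np.zeros(num_lanes)         # thread-local partial sums
    for start in range(0, len(x), chunk):
        block = x[start:start + chunk]
        block_max = block.max()              # the only per-chunk reduction
        new_max = max(running_max, block_max)
        lane_denom *= np.exp(running_max - new_max)   # rescale local partials
        for lane in range(num_lanes):                 # each lane owns a slice
            contrib = np.exp(block[lane::num_lanes] - block_max).sum()
            lane_denom[lane] += contrib * np.exp(block_max - new_max)
        running_max = new_max
    return running_max, lane_denom.sum()     # single cross-lane sum at the end


x = np.random.randn(64)
_, d = deferred_softmax_denominator(x)
assert np.isclose(d, np.exp(x - x.max()).sum())  # matches the direct denominator
```

This is what the rewritten kernels below do with `threadlocal_running_denominator`: the per-chunk `BlockReduce` over partial sums disappears, and only the max reduction plus one final sum remain as cross-thread operations.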
--- benchmarks/bench_sampling.py | 80 ++++++++++ benchmarks/bench_softmax.py | 214 ++++++++++++++++++++++++++ include/flashinfer/sampling.cuh | 261 +++++++++++++++++--------------- 3 files changed, 431 insertions(+), 124 deletions(-) create mode 100755 benchmarks/bench_softmax.py diff --git a/benchmarks/bench_sampling.py b/benchmarks/bench_sampling.py index 2eb2de3875..cc2406e43f 100644 --- a/benchmarks/bench_sampling.py +++ b/benchmarks/bench_sampling.py @@ -220,6 +220,86 @@ def main(): f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) + print("---") + print("top-p renorm probs") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + torch.manual_seed(42) + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for p in [0.1, 0.5, 0.9]: + logits = distrib((batch_size, vocab_size), device="cuda") + probs = torch.softmax(logits, dim=-1) + measurements = bench_gpu_time( + lambda: flashinfer.sampling.top_p_renorm_probs(probs, p), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io = probs.numel() * probs.element_size() * 2 + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + + print("---") + print("top-k renorm probs") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + torch.manual_seed(42) + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for k in [10, 100, 1000, 5000]: + logits = distrib((batch_size, vocab_size), device="cuda") + probs = torch.softmax(logits, dim=-1) + measurements = bench_gpu_time( + lambda: flashinfer.sampling.top_k_renorm_probs(probs, k), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io = probs.numel() * probs.element_size() * 2 + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + + print("---") + print("top-k mask logits") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + torch.manual_seed(42) + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for k in [10, 100, 1000, 5000]: + logits = distrib((batch_size, vocab_size), device="cuda") + measurements = bench_gpu_time( + lambda: flashinfer.sampling.top_k_mask_logits(logits, k), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io = logits.numel() * logits.element_size() * 2 + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + if __name__ == "__main__": main() diff --git a/benchmarks/bench_softmax.py b/benchmarks/bench_softmax.py new file mode 100755 index 0000000000..6da8dc9fcb --- /dev/null +++ b/benchmarks/bench_softmax.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +Benchmark script comparing torch.softmax vs flashinfer.softmax performance. 
+Creates a heatmap showing speedup across different batch sizes and hidden dimensions. +""" + +import numpy as np +import torch +import matplotlib.pyplot as plt +import seaborn as sns +from typing import List, Tuple +import flashinfer +from flashinfer.testing.utils import bench_gpu_time + + +@torch.inference_mode() +def benchmark_torch_softmax(logits: torch.Tensor) -> float: + """Benchmark torch's native softmax.""" + measurements = bench_gpu_time( + lambda: torch.softmax(logits, dim=-1), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + return np.median(measurements) + + +@torch.inference_mode() +def benchmark_flashinfer_softmax(logits: torch.Tensor) -> float: + """Benchmark flashinfer's softmax.""" + measurements = bench_gpu_time( + lambda: flashinfer.sampling.softmax(logits, temperature=None, enable_pdl=False), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + return np.median(measurements) + + +def run_benchmark( + batch_sizes: List[int], hidden_sizes: List[int] +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Run benchmarks for all combinations of batch_size and hidden_size. + + Returns: + torch_times: 2D array of torch softmax times (ms) + flashinfer_times: 2D array of flashinfer softmax times (ms) + speedups: 2D array of speedup ratios (torch_time / flashinfer_time) + """ + n_batch = len(batch_sizes) + n_hidden = len(hidden_sizes) + + torch_times = np.zeros((n_batch, n_hidden)) + flashinfer_times = np.zeros((n_batch, n_hidden)) + speedups = np.zeros((n_batch, n_hidden)) + + print("Running benchmarks...") + print("=" * 100) + print( + f"{'Batch Size':<12} {'Hidden Size':<12} {'Torch (ms)':<15} " + f"{'FlashInfer (ms)':<18} {'Speedup':<10} {'Bandwidth (GB/s)':<18}" + ) + print("=" * 100) + + for i, batch_size in enumerate(batch_sizes): + for j, hidden_size in enumerate(hidden_sizes): + # Generate random logits + torch.manual_seed(42) + logits = torch.randn( + batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Benchmark torch softmax + torch_time_ms = benchmark_torch_softmax(logits) + torch_times[i, j] = torch_time_ms + + # Benchmark flashinfer softmax + flashinfer_time_ms = benchmark_flashinfer_softmax(logits) + flashinfer_times[i, j] = flashinfer_time_ms + + # Calculate speedup + speedup = torch_time_ms / flashinfer_time_ms + speedups[i, j] = speedup + + # Calculate effective bandwidth (read + write) + io_bytes = logits.numel() * logits.element_size() * 2 + bandwidth_gb_s = io_bytes * 1e-6 / flashinfer_time_ms + + print( + f"{batch_size:<12} {hidden_size:<12} {torch_time_ms:<15.4f} " + f"{flashinfer_time_ms:<18.4f} {speedup:<10.2f}x {bandwidth_gb_s:<18.2f}" + ) + + print("=" * 100) + return torch_times, flashinfer_times, speedups + + +def plot_heatmap( + speedups: np.ndarray, + batch_sizes: List[int], + hidden_sizes: List[int], + save_path: str = "softmax_speedup_heatmap.png", +): + """Create and save a heatmap of speedup values.""" + # Create figure + fig, ax = plt.subplots(figsize=(12, 8)) + + # Create heatmap + sns.heatmap( + speedups, + annot=True, + fmt=".2f", + cmap="RdYlGn", + center=1.0, + cbar_kws={"label": "Speedup (x)"}, + xticklabels=[f"{h // 1000}K" for h in hidden_sizes], + yticklabels=batch_sizes, + ax=ax, + vmin=0.5, # Adjust color scale + vmax=max(3.0, speedups.max()), # Dynamic upper bound + ) + + ax.set_xlabel("Hidden Size", fontsize=12, fontweight="bold") + ax.set_ylabel("Batch Size", fontsize=12, fontweight="bold") + ax.set_title( + "FlashInfer Softmax Speedup vs PyTorch (Higher is Better)", + fontsize=14, + fontweight="bold", + 
pad=20, + ) + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches="tight") + print(f"\nHeatmap saved to: {save_path}") + + # Also create a performance comparison plot + _, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) + + # Plot 2: Speedup trends across batch sizes + for j, hidden_size in enumerate(hidden_sizes): + ax2.plot( + batch_sizes, + speedups[:, j], + marker="o", + label=f"Hidden={hidden_size // 1000}K", + linewidth=2, + ) + + ax2.set_xlabel("Batch Size", fontsize=12, fontweight="bold") + ax2.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold") + ax2.set_title("Speedup vs Batch Size", fontsize=13, fontweight="bold") + ax2.set_xscale("log", base=2) + ax2.grid(True, alpha=0.3) + ax2.legend(fontsize=9) + ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup") + + # Plot 1: Speedup trends across hidden sizes + for i, batch_size in enumerate(batch_sizes[::2]): # Sample every other batch size + idx = i * 2 + ax1.plot( + [h // 1000 for h in hidden_sizes], + speedups[idx, :], + marker="s", + label=f"Batch={batch_size}", + linewidth=2, + ) + + ax1.set_xlabel("Hidden Size (K)", fontsize=12, fontweight="bold") + ax1.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold") + ax1.set_title("Speedup vs Hidden Size", fontsize=13, fontweight="bold") + ax1.grid(True, alpha=0.3) + ax1.legend(fontsize=9) + ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5) + + plt.tight_layout() + comparison_path = save_path.replace(".png", "_trends.png") + plt.savefig(comparison_path, dpi=300, bbox_inches="tight") + print(f"Trend plots saved to: {comparison_path}") + + +def main(): + """Main benchmark execution.""" + # Configuration + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + hidden_sizes = [32000, 64000, 128000, 256000] + + print("=" * 100) + print("FlashInfer vs PyTorch Softmax Benchmark") + print("=" * 100) + print(f"Batch sizes: {batch_sizes}") + print(f"Hidden sizes: {hidden_sizes}") + print(f"Device: {torch.cuda.get_device_name()}") + print("=" * 100) + print() + + # Run benchmarks + _, _, speedups = run_benchmark(batch_sizes, hidden_sizes) + + # Print summary statistics + print("\nSummary Statistics:") + print("=" * 100) + print(f"Average speedup: {np.mean(speedups):.2f}x") + print(f"Median speedup: {np.median(speedups):.2f}x") + print(f"Min speedup: {np.min(speedups):.2f}x") + print(f"Max speedup: {np.max(speedups):.2f}x") + print("=" * 100) + + # Generate heatmap + plot_heatmap(speedups, batch_sizes, hidden_sizes) + + print("\nBenchmark complete!") + + +if __name__ == "__main__": + main() diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 6b134630cf..f3b188abec 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -333,6 +333,7 @@ __global__ void OnlineSoftmaxFusedKernel(DType* logits, DType* output, DType* te float running_max = -cuda::std::numeric_limits::infinity(); float running_denominator = 0.0f; + float threadlocal_running_denominator = 0.0f; #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); @@ -368,39 +369,32 @@ __global__ void OnlineSoftmaxFusedKernel(DType* logits, DType* output, DType* te } __syncthreads(); block_max = temp_storage.shared_state.max_val; - // if block_max is -inf, then this block contains all -inf values, so we can skip updating if (!isinf(block_max)) { - float thread_sum = 0.0f; + float threadlocal_sum = 0.0f; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - 
thread_sum += __expf(logits_vec[j] - block_max); - } - - float block_sum = - cub::BlockReduce(temp_storage.block_prim.reduce).Sum(thread_sum); - __syncthreads(); - - if (tx == 0) { - float new_max = max(running_max, block_max); - running_denominator = running_denominator * __expf(running_max - new_max) + - block_sum * __expf(block_max - new_max); - running_max = new_max; - - temp_storage.shared_state.max_val = running_max; - temp_storage.shared_state.denominator = running_denominator; + threadlocal_sum += __expf(logits_vec[j] - block_max); } - __syncthreads(); - running_max = temp_storage.shared_state.max_val; - running_denominator = temp_storage.shared_state.denominator; + float new_max = max(running_max, block_max); + threadlocal_running_denominator = + threadlocal_running_denominator * __expf(running_max - new_max) + + threadlocal_sum * __expf(block_max - new_max); + running_max = new_max; } } + running_denominator = cub::BlockReduce(temp_storage.block_prim.reduce) + .Sum(threadlocal_running_denominator); + if (tx == 0) { + temp_storage.shared_state.denominator = running_denominator; + } + __syncthreads(); + running_denominator = temp_storage.shared_state.denominator; + const float final_max = running_max; const float inv_denominator = 1.0f / running_denominator; - __syncthreads(); - // Pass 2: Normalize in place vec_t prob_vec; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { @@ -458,6 +452,7 @@ __global__ void OnlineSoftmaxMapKernel(DType* logits, PartialSoftmaxResult* part vec_t logits_vec; float running_max = -cuda::std::numeric_limits::infinity(); float running_denominator = 0.0f; + float threadlocal_running_denominator = 0.0f; #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); @@ -489,31 +484,27 @@ __global__ void OnlineSoftmaxMapKernel(DType* logits, PartialSoftmaxResult* part // if block_max is -inf, then this block contains all -inf values, so we can skip updating if (!isinf(block_max)) { - float thread_sum = 0.0f; + float threadlocal_sum = 0.0f; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - thread_sum += __expf(logits_vec[j] - block_max); - } - - float block_sum = - cub::BlockReduce(temp_storage.block_prim.reduce).Sum(thread_sum); - __syncthreads(); - - if (tx == 0) { - float new_max = max(running_max, block_max); - running_denominator = running_denominator * __expf(running_max - new_max) + - block_sum * __expf(block_max - new_max); - running_max = new_max; - - temp_storage.shared_state.max_val = running_max; - temp_storage.shared_state.denominator = running_denominator; + threadlocal_sum += __expf(logits_vec[j] - block_max); } - __syncthreads(); - running_max = temp_storage.shared_state.max_val; - running_denominator = temp_storage.shared_state.denominator; + float new_max = max(running_max, block_max); + threadlocal_running_denominator = + threadlocal_running_denominator * __expf(running_max - new_max) + + threadlocal_sum * __expf(block_max - new_max); + running_max = new_max; } } + running_denominator = cub::BlockReduce(temp_storage.block_prim.reduce) + .Sum(threadlocal_running_denominator); + if (tx == 0) { + temp_storage.shared_state.denominator = running_denominator; + } + __syncthreads(); + running_denominator = temp_storage.shared_state.denominator; + if (tx == 0) { partial_results[bx * num_slices + by] = {running_max, running_denominator}; } @@ -887,6 +878,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* double pivot_1 = 
(pivot_0 + high) / 2; ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; + ValueCount threadlocal_gt_pivot_0{0, 0}, threadlocal_gt_pivot_1{0, 0}; #pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); @@ -903,26 +895,27 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* probs_gt_pivot_1[j] = { (probs_vec[j] > pivot_1) ? probs_vec[j] : 0, (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + threadlocal_gt_pivot_0 += probs_gt_pivot_0[j]; + threadlocal_gt_pivot_1 += probs_gt_pivot_1[j]; } + } + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(threadlocal_gt_pivot_0); + if (tx == 0) { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; - aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); - if (tx == 0) { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; - - aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); - if (tx == 0) { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; - } - __syncthreads(); - aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(threadlocal_gt_pivot_1); + if (tx == 0) { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; if (aggregate_gt_pivot_0.count < k) { // case 1: pivot_0 accepted break; @@ -1000,6 +993,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* double pivot_1 = (pivot_0 + high) / 2; float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; + float threadlocal_aggregate_gt_pivot_0 = 0; + float threadlocal_aggregate_gt_pivot_1 = 0; #pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); @@ -1012,24 +1007,26 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0; + threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0[j]; + threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1[j]; } + } + aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce) + .Sum(threadlocal_aggregate_gt_pivot_0); + if (tx == 0) { + temp_storage.block_aggregate.value = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.value; - aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce) - .template Sum(probs_gt_pivot_0); - if (tx == 0) { - temp_storage.block_aggregate.value = aggregate_gt_pivot_0; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.value; - - aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce) - .template Sum(probs_gt_pivot_1); - if (tx == 0) { - temp_storage.block_aggregate.value = aggregate_gt_pivot_1; - } - __syncthreads(); - aggregate_gt_pivot_1 = temp_storage.block_aggregate.value; + aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce) + .Sum(threadlocal_aggregate_gt_pivot_1); + if (tx == 0) { + temp_storage.block_aggregate.value = aggregate_gt_pivot_1; } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.value; + if (aggregate_gt_pivot_0 < top_p) { // case 1: pivot_0 accepted break; @@ -1077,6 +1074,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp vec_t probs_vec; float aggregate_gt_pivot = 0; + float threadlocal_aggregate_gt_pivot = 0; #pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); @@ -1088,15 +1086,16 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_gt_pivot[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0; + threadlocal_aggregate_gt_pivot += probs_gt_pivot[j]; } + } - aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.block_aggregate.value = aggregate_gt_pivot; - } - __syncthreads(); + aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) + .Sum(threadlocal_aggregate_gt_pivot); + if (tx == 0) { + temp_storage.block_aggregate.value = aggregate_gt_pivot; } + __syncthreads(); float aggregate = 0; float q = temp_storage.block_aggregate.value; @@ -1187,6 +1186,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, double pivot_1 = (pivot_0 + high) / 2; ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; + ValueCount threadlocal_aggregate_gt_pivot_0{0, 0}; + ValueCount threadlocal_aggregate_gt_pivot_1{0, 0}; #pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); @@ -1203,26 +1204,27 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, probs_gt_pivot_1[j] = { (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0, (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0[j]; + threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1[j]; } + } + aggregate_gt_pivot_0 += + BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) + .Sum(threadlocal_aggregate_gt_pivot_0); + if (tx == 0) { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; - aggregate_gt_pivot_0 += - BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); - if (tx == 0) { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; - - aggregate_gt_pivot_1 += - BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); - if (tx == 0) { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; - } - __syncthreads(); - aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; + aggregate_gt_pivot_1 += + BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) + .Sum(threadlocal_aggregate_gt_pivot_1); + if (tx == 0) { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; if (aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) { // case 1: pivot_0 accepted break; @@ -1663,6 +1665,8 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float* float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; min_gt_low = high; max_le_high = low; + float threadlocal_aggregate_gt_pivot_0 = 0; + float threadlocal_aggregate_gt_pivot_1 = 0; #pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); @@ -1682,18 +1686,19 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float* if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { max_le_high = max(max_le_high, probs_vec[j]); } + threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0[j]; + threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1[j]; } - - aggregate_gt_pivot_0 += - BlockReduce(temp_storage.block_prim.reduce) - .template Sum(probs_gt_pivot_0); - __syncthreads(); - - aggregate_gt_pivot_1 += - BlockReduce(temp_storage.block_prim.reduce) - .template Sum(probs_gt_pivot_1); - __syncthreads(); } + aggregate_gt_pivot_0 = + BlockReduce(temp_storage.block_prim.reduce) + .Sum(threadlocal_aggregate_gt_pivot_0); + __syncthreads(); + aggregate_gt_pivot_1 = + BlockReduce(temp_storage.block_prim.reduce) + .Sum(threadlocal_aggregate_gt_pivot_1); + __syncthreads(); + min_gt_low = BlockReduce(temp_storage.block_prim.reduce) .Reduce(min_gt_low, MinReduceOp{}); __syncthreads(); @@ -1783,6 +1788,8 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType int aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; min_gt_low = high; max_le_high = low; + int threadlocal_aggregate_gt_pivot_0 = 0; + int threadlocal_aggregate_gt_pivot_1 = 0; #pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { logits_vec.fill(0); @@ -1803,18 +1810,20 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { max_le_high = max(max_le_high, logits_vec[j]); } + threadlocal_aggregate_gt_pivot_0 += 
probs_gt_pivot_0_count[j]; + threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1_count[j]; } + } + aggregate_gt_pivot_0 += + BlockReduce(temp_storage.block_prim.reduce_int) + .Sum(threadlocal_aggregate_gt_pivot_0); + __syncthreads(); - aggregate_gt_pivot_0 += - BlockReduce(temp_storage.block_prim.reduce_int) - .Sum(probs_gt_pivot_0_count); - __syncthreads(); + aggregate_gt_pivot_1 += + BlockReduce(temp_storage.block_prim.reduce_int) + .Sum(threadlocal_aggregate_gt_pivot_1); + __syncthreads(); - aggregate_gt_pivot_1 += - BlockReduce(temp_storage.block_prim.reduce_int) - .Sum(probs_gt_pivot_1_count); - __syncthreads(); - } min_gt_low = BlockReduce(temp_storage.block_prim.reduce) .Reduce(min_gt_low, MinReduceOp{}); @@ -1901,6 +1910,8 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; min_gt_low = high; max_le_high = low; + ValueCount threadlocal_aggregate_gt_pivot_0{0, 0}, + threadlocal_aggregate_gt_pivot_1{0, 0}; #pragma unroll 1 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); @@ -1923,18 +1934,20 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { max_le_high = max(max_le_high, probs_vec[j]); } + threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0_pair[j]; + threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1_pair[j]; } + } + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(threadlocal_aggregate_gt_pivot_0); + __syncthreads(); - aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .template Sum(probs_gt_pivot_0_pair); - __syncthreads(); + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(threadlocal_aggregate_gt_pivot_1); + __syncthreads(); - aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .template Sum(probs_gt_pivot_1_pair); - __syncthreads(); - } min_gt_low = BlockReduce(temp_storage.block_prim.reduce) .Reduce(min_gt_low, MinReduceOp{}); From f566d49cec23db587f68556193244f8bd106bec5 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Thu, 6 Nov 2025 23:07:02 -0800 Subject: [PATCH 034/130] misc: Add XQA decode to microbenchmark for sm90 and sm120 (#2055) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description In #2001 , XQA decode kernels became available through `trtllm_batch_decode_with_kv_cache` on SM90 and SM120. Current PR adds the ability to benchmark through the microbenchmark. Example microbenchmark command and outputs before and after: ``` ### Before current PR: ## SM90 (H200) $ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti [WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 9.0. Skipping. 
[PERF] fa2 :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec [PERF] cudnn :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.519 TFLOPs/sec; achieved tb_per_sec 1.692 TB/sec ## SM120 (RTX 5090) $ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti [WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 12.0. Skipping. [PERF] fa2 :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.204 TFLOPs/sec; achieved tb_per_sec 1.027 TB/sec [PERF] cudnn :: median time 0.030 ms; std 0.000 ms; achieved tflops 8.943 TFLOPs/sec; achieved tb_per_sec 1.119 TB/sec ### After current PR: ## SM90 (H200) $ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti [PERF] fa2 :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec [PERF] trtllm-gen-nati:: median time 0.019 ms; std 0.002 ms; achieved tflops 13.820 TFLOPs/sec; achieved tb_per_sec 1.729 TB/sec [PERF] cudnn :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.574 TFLOPs/sec; achieved tb_per_sec 1.698 TB/sec ## SM120 (RTX 5090) $ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti [PERF] fa2 :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.121 TFLOPs/sec; achieved tb_per_sec 1.016 TB/sec [PERF] trtllm-gen-nati:: median time 0.034 ms; std 0.001 ms; achieved tflops 7.903 TFLOPs/sec; achieved tb_per_sec 0.989 TB/sec [PERF] cudnn :: median time 0.030 ms; std 0.001 ms; achieved tflops 9.020 TFLOPs/sec; achieved tb_per_sec 1.129 TB/sec ``` ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Standardized backend identifier to "trtllm-native" and expanded its support across benchmark routines and utilities. * Argument parsing now canonicalizes deprecated backend aliases and emits a deprecation warning when encountered. * **Documentation** * README and tool-facing messages updated to use the canonical backend name and include contextual notes about the change. 
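The alias handling added in `benchmarks/routines/attention.py` below behaves as in this small usage sketch; the import path is an assumption based on the benchmark layout, not something shown in the diff.

```python
from routines.attention import normalize_backends  # path assumed; see diff below

backends = normalize_backends(["fa2", "trtllm-gen-native", "cudnn"])
# Prints a deprecation warning for 'trtllm-gen-native' and returns the canonical names:
print(backends)  # ['fa2', 'trtllm-native', 'cudnn']
```

Existing scripts that still pass `trtllm-gen-native` therefore keep working until the alias is removed.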
--- benchmarks/README.md | 14 ++-- benchmarks/routines/attention.py | 65 ++++++++++++++----- .../routines/flashinfer_benchmark_utils.py | 24 ++++--- 3 files changed, 68 insertions(+), 35 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index f41d695cdc..e7e17156a4 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -117,7 +117,7 @@ The output CSV will contain detailed metrics including: | `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) | | `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. | | `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. | -| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-gen-native, cublas| +| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-native, cublas| ### Attention Flags | Flag | Description | @@ -213,14 +213,14 @@ Legend: - cutlass: CUTLASS - trtllm: TensorRT-LLM - trtllm-gen: TensorRT-LLM (generic wrapper) -- trtllm-gen-native: TensorRT-LLM (native API) +- trtllm-native: TensorRT-LLM (native API) --> | Routine | 7.5 | 8.0 | 8.6 | 8.9 | 9.0 | 10.0 | 10.3 | 12.0 | |---------|-----|-----|-----|-----|-----|-------|-------|-------| -| **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native | fa2, fa2_tc, cudnn | -| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-gen-native | fa2, cudnn, trtllm-gen, trtllm-gen-native | fa2, cudnn | -| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-gen-native | fa2, cudnn, cutlass, trtllm-gen-native | fa2, cudnn | -| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-gen-native | fa2, cutlass, trtllm-gen-native | fa2 | +| **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn | +| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn | +| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn | +| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native | fa2, cutlass, trtllm-native | fa2 | | **gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | | **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | | **bmm_fp8** | | | | cudnn, cublas | cudnn, cublas | cudnn, cublas, cutlass | cudnn, cublas, cutlass | cudnn, cublas | @@ -238,4 +238,4 @@ Backend Legend: - cutlass: CUTLASS - trtllm: TensorRT-LLM - trtllm-gen: TensorRT-LLM -- trtllm-gen-native: TensorRT-LLM (out-of-wrapper) +- trtllm-native: TensorRT-LLM (out-of-wrapper) diff --git a/benchmarks/routines/attention.py 
b/benchmarks/routines/attention.py index 9dd2442eed..e88b176f13 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -19,6 +19,30 @@ ) +def normalize_backends(backends): + """ + Normalize backend names planned for deprecation and print warnings. + Currently: + - Replaces deprecated 'trtllm-gen-native' with 'trtllm-native'. + + Args: + backends: List of backend names + + Returns: + List of normalized backend names + """ + normalized = [] + for backend in backends: + if backend == "trtllm-gen-native": + print( + "[WARNING] Backend name 'trtllm-gen-native' has been renamed to 'trtllm-native' and will be removed in a future release. " + ) + normalized.append("trtllm-native") + else: + normalized.append(backend) + return normalized + + def run_attention_test(args): """ Run an attention test. @@ -66,7 +90,8 @@ def parse_attention_args(line, parser): "cudnn", "cutlass", "trtllm-gen", - "trtllm-gen-native", + "trtllm-native", + "trtllm-gen-native", # Deprecated, will be removed in future ], help="Kernel backends to test. Default: fa2", ) @@ -151,6 +176,10 @@ def parse_attention_args(line, parser): ) args = parser.parse_args(line) + + # Normalize backend names (handle deprecated names) + args.backends = normalize_backends(args.backends) + if args.verbose >= 1: print(f"[INFO] {args = }") return args @@ -185,7 +214,7 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len def testBatchDecodeWithPagedKVCacheWrapper(args): """ Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API. - Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native backends. + Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native backends. This test: 1. Creates paged KV cache and query tensors @@ -490,7 +519,7 @@ def run_backend_wrapper(backend): batch_offsets_q=ragged_q, batch_offsets_o=ragged_q, ) - elif backend == "trtllm-gen-native": + elif backend == "trtllm-native": return flashinfer.decode.trtllm_batch_decode_with_kv_cache( query=q.contiguous(), kv_cache=kv_cache, @@ -614,7 +643,7 @@ def run_backend_wrapper(backend): def testBatchPrefillWithPagedKVCacheWrapper(args): """ Test BatchPrefillWithPagedKVCacheWrapper API and equivalent cuDNN API. - Supports fa2, fa3, trtllm-gen, trtllm-gen-native, and cudnn backends. + Supports fa2, fa3, trtllm-gen, trtllm-native, and cudnn backends. This test: 1. Creates paged KV cache and query tensors for prefill @@ -697,13 +726,13 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): remove_trtllm = True if remove_trtllm: backends.remove("trtllm-gen") - if "trtllm-gen-native" in backends: + if "trtllm-native" in backends: remove_trtllm_native = False if not causal: - print("[INFO] trtllm-gen-native backend currently requires causal = True") + print("[INFO] trtllm-native backend currently requires causal = True") remove_trtllm_native = True if remove_trtllm_native: - backends.remove("trtllm-gen-native") + backends.remove("trtllm-native") if "cutlass" in backends: print("[INFO] CUTLASS backend does not support prefill. 
Skipping.") @@ -955,7 +984,7 @@ def run_backend_wrapper(backend): batch_offsets_q=q_indptr, batch_offsets_o=q_indptr, )[0] - elif backend == "trtllm-gen-native": + elif backend == "trtllm-native": return flashinfer.prefill.trtllm_batch_context_with_kv_cache( query=q, kv_cache=kv_cache, @@ -1178,21 +1207,21 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): remove_trtllm = True if remove_trtllm: backends.remove("trtllm-gen") - if "trtllm-gen-native" in backends: + if "trtllm-native" in backends: remove_trtllm_native = False if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, ]: - print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.") + print("[INFO] trtllm-native backend does not support FP8. Skipping.") remove_trtllm_native = True if not (head_dim_qk == 192 and head_dim_vo == 128): print( - "[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128" + "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128" ) remove_trtllm_native = True if remove_trtllm_native: - backends.remove("trtllm-gen-native") + backends.remove("trtllm-native") if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") @@ -1404,7 +1433,7 @@ def run_backend_wrapper(backend): batch_offsets_stats=batch_offsets_stats, is_cuda_graph_compatible=True, )[0] - elif backend == "trtllm-gen-native": + elif backend == "trtllm-native": return flashinfer.prefill.trtllm_ragged_attention_deepseek( query=q, key=k, @@ -1538,7 +1567,7 @@ def run_backend_wrapper(backend): def testBatchMLAPagedAttentionWrapper(args): """ Test BatchMLAPagedAttentionWrapper and equivalent APIs. - Supports fa2, fa3, cutlass, and trtllm-gen-native. + Supports fa2, fa3, cutlass, and trtllm-native. This test: 1. Creates paged query and key-value cache tensors @@ -1634,15 +1663,15 @@ def testBatchMLAPagedAttentionWrapper(args): remove_cutlass = True if remove_cutlass: backends.remove("cutlass") - if "trtllm-gen-native" in backends: + if "trtllm-native" in backends: remove_trtllm_native = False if page_size not in [32, 64]: print( - "[INFO] trtllm-gen-native backend only supports page size 32 or 64. Skipping." + "[INFO] trtllm-native backend only supports page size 32 or 64. Skipping." ) remove_trtllm_native = True if remove_trtllm_native: - backends.remove("trtllm-gen-native") + backends.remove("trtllm-native") if len(backends) == 0: print("[ERROR] No backends to test. 
Exiting.") return res @@ -1807,7 +1836,7 @@ def run_backend_wrapper(backend): page_table=block_tables, return_lse=False, ) - if backend == "trtllm-gen-native": + elif backend == "trtllm-native": return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=q.unsqueeze(1), kv_cache=kv_cache.unsqueeze(1), diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 3836e03630..8798f8340f 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -162,43 +162,47 @@ def dtype_str_to_torch_dtype(dtype_str): routine_cc_to_supported_backends = { # ATTENTION "BatchDecodeWithPagedKVCacheWrapper": { + # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache "7.5": ["fa2"], "8.0": ["fa2", "fa2_tc", "cudnn"], "8.6": ["fa2", "fa2_tc", "cudnn"], "8.9": ["fa2", "fa2_tc", "cudnn"], - "9.0": ["fa2", "fa2_tc", "cudnn"], - "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"], - "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"], - "12.0": ["fa2", "fa2_tc", "cudnn"], + "9.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"], + "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"], + "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"], + "12.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"], }, "BatchPrefillWithPagedKVCacheWrapper": { + # NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache "7.5": [], "8.0": ["fa2", "cudnn"], "8.6": ["fa2", "cudnn"], "8.9": ["fa2", "cudnn"], "9.0": ["fa2", "fa3", "cudnn"], - "10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-gen-native"], - "10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-gen-native"], + "10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"], + "10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"], "12.0": ["fa2", "cudnn"], }, "BatchPrefillWithRaggedKVCacheWrapper": { + # NOTE: trtllm-native calls trtllm_ragged_attention_deepseek "7.5": [], "8.0": ["fa2", "cudnn"], "8.6": ["fa2", "cudnn"], "8.9": ["fa2", "cudnn"], "9.0": ["fa2", "fa3", "cudnn"], - "10.0": ["fa2", "cudnn", "cutlass", "trtllm-gen-native"], - "10.3": ["fa2", "cudnn", "cutlass", "trtllm-gen-native"], + "10.0": ["fa2", "cudnn", "cutlass", "trtllm-native"], + "10.3": ["fa2", "cudnn", "cutlass", "trtllm-native"], "12.0": ["fa2", "cudnn"], }, "BatchMLAPagedAttentionWrapper": { + # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache_mla "7.5": [], "8.0": ["fa2"], "8.6": ["fa2"], "8.9": ["fa2"], "9.0": ["fa2", "fa3"], - "10.0": ["fa2", "cutlass", "trtllm-gen-native"], - "10.3": ["fa2", "cutlass", "trtllm-gen-native"], + "10.0": ["fa2", "cutlass", "trtllm-native"], + "10.3": ["fa2", "cutlass", "trtllm-native"], "12.0": ["fa2"], }, # GEMM From 36d24632fbacc2a08933105e4a45fbf4cdbfa551 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Fri, 7 Nov 2025 01:59:51 -0800 Subject: [PATCH 035/130] test: Skip unsupported SM Archs for newly added trtllm MoE test (#2060) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description `tests/moe/test_trtllm_gen_routed_fused_moe.py` was newly added in #2049, but does not have an SM arch check, which causes unit test failures on non SM10X devices. Current PR adds skips ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. 
### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Added GPU compute capability checks to MOE tests. Tests are now skipped on unsupported hardware, requiring SM100 or SM103 GPUs to execute. --- tests/moe/test_trtllm_gen_routed_fused_moe.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index 8bda03e971..be39bda225 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -36,6 +36,8 @@ routing_reference_topk, ) +from flashinfer.utils import get_compute_capability + @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) @pytest.mark.parametrize("hidden_size", [1024, 2048, 3072, 4096]) @@ -60,6 +62,9 @@ def test_trtllm_gen_routed_fused_moe( routing_method_type: RoutingMethodType, quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], ): + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] not in [10]: + pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") torch.manual_seed(42) device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) From 3cb8f9ab3fc851392d12fd6a22f242b25ffd266d Mon Sep 17 00:00:00 2001 From: Jimmy Zhou <79552142+jimmyzho@users.noreply.github.com> Date: Fri, 7 Nov 2025 12:17:57 -0500 Subject: [PATCH 036/130] feat: suitable_auto_backends to prune auto backends, bmm_fp8 refactor, heuristic_func intake (#2029) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Improvements** * Expanded FP8 BMM backend support with explicit Cutlass SM10x/SM12x handling, safer fallbacks (no unconditional hard failures), and richer auto-selection that exposes viable backends and respects device capabilities. * Added heuristic-driven backend preference for auto and cutlass paths. * **Refactor** * Backend gating reorganized into per-backend capability checks, a shared problem-size pre-check, and heuristic selection; decorator now exposes suitable_auto_backends and capability extraction. * **Tests** * Added tests validating auto backend discovery and heuristic ordering. 
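A minimal usage sketch of the auto-selection path summarized above. This is illustrative only: it assumes the `bmm_fp8` signature and decorator wiring shown in the `gemm.py` diff below, and the shapes and unit scale factors are made up.

```python
import torch
from flashinfer.gemm import bmm_fp8

# Illustrative fp8 batched matmul: (16, 48, 64) @ (16, 64, 80) -> (16, 48, 80).
A = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
B = (
    torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16)
    .transpose(-2, -1)  # second operand laid out k-major, as fp8 GEMM backends typically expect
    .to(torch.float8_e4m3fn)
)
A_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
B_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

# With backend="auto", the @backend_requirement wrapper runs the per-backend
# requirement checks plus _heuristic_func_bmm_fp8 and stores the surviving,
# preference-ordered candidates on bmm_fp8.suitable_auto_backends before
# dispatching to fp8_gemm_sm100.
out = bmm_fp8(A, B, A_scale, B_scale, dtype=torch.bfloat16, backend="auto")
print(out.shape, out.dtype)  # expected: torch.Size([16, 48, 80]) torch.bfloat16
```

If a backend's requirement check raises (for example, the cutlass check rejects e5m2 inputs), `suitable_auto_backends` simply drops that backend from the candidate list instead of failing the call.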
--- flashinfer/gemm.py | 142 +++++++++++++++++++++++++-------- flashinfer/utils.py | 102 ++++++++++++++++------- tests/utils/test_decorators.py | 91 +++++++++++++++++++++ 3 files changed, 274 insertions(+), 61 deletions(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 9f00cc6e25..fc4c1b8885 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -364,31 +364,15 @@ def fp8_gemm_sm100( runner_names: List[str], ) -> None: runners = [] - # No e5m2 for cutlass - is_e5m2 = a.dtype == torch.float8_e5m2 or b.dtype == torch.float8_e5m2 - is_sm_supported = _match_sm_version(a.device, ["100", "103", "110"]) - is_sm120_supported = _match_sm_version(a.device, ["120", "121"]) - - if "cutlass" in runner_names and not is_e5m2: - if is_sm_supported: - runners.append( - get_gemm_sm100_module_cutlass_fp8().cutlass_fp8_gemm_runner() - ) - elif is_sm120_supported: - k_dim = a.shape[-1] if a.dim() == 2 else a.shape[2] - if k_dim >= 128: - runners.append( - get_gemm_sm120_module_cutlass_fp8().cutlass_fp8_gemm_runner() - ) + if "cutlass_sm10x" in runner_names: + runners.append(get_gemm_sm100_module_cutlass_fp8().cutlass_fp8_gemm_runner()) + if "cutlass_sm12x" in runner_names: + runners.append(get_gemm_sm120_module_cutlass_fp8().cutlass_fp8_gemm_runner()) if "cublas" in runner_names: runners.append(get_gemm_module().cublas_fp8_gemm_runner()) - if CUDNN_AVAILABLE and "cudnn" in runner_names: + if "cudnn" in runner_names: runners.append(_cudnn_gemm_fp8_runner()) - - if len(runners) == 0: - major, minor = get_compute_capability(torch.device("cuda")) - raise ValueError(f"No valid runner found for current device sm{major}{minor}") - + assert runners, "No suitable runners found" tuner = AutoTuner.get() a_tensor_index = 0 out_tensor_index = 4 @@ -2013,6 +1997,101 @@ def mm_fp4( return out +@supported_compute_capability([89, 90, 100, 103, 120, 121]) +def _cudnn_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + _check_cudnn_availability() + return True + + +@supported_compute_capability([89, 90, 100, 103, 120, 121]) +def _cublas_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + return True + + +@supported_compute_capability([100, 103, 110, 120, 121]) +def _cutlass_bmm_fp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: + raise ValueError("e5m2 is not supported for bmm_fp8 with cutlass backend") + return True + + +def _check_bmm_fp8_problem_size( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", +): + _validate_fp8_output_dtype(dtype) + return True + + +def _heuristic_func_bmm_fp8( + suitable_backends: List[str], + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn", "cublas", "cutlass", "auto"] 
= "cublas", +): + # No e5m2 for cutlass + is_e5m2 = A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2 + is_sm_supported = _match_sm_version(A.device, ["100", "103", "110"]) + is_sm120_supported = _match_sm_version(A.device, ["120", "121"]) + + # preserve order of ["cudnn", "cublas", "cutlass"] + heuristic_backends = [] + if "cutlass" in suitable_backends and not is_e5m2: + if is_sm_supported: + heuristic_backends.append("cutlass_sm10x") + elif is_sm120_supported: + k_dim = A.shape[-1] if A.dim() == 2 else A.shape[2] + if k_dim >= 128: + heuristic_backends.append("cutlass_sm12x") + if "cublas" in suitable_backends: + heuristic_backends.append("cublas") + if CUDNN_AVAILABLE and "cudnn" in suitable_backends: + heuristic_backends.append("cudnn") + return heuristic_backends + + +@backend_requirement( + { + "cudnn": _cudnn_bmm_fp8_requirement, + "cublas": _cublas_bmm_fp8_requirement, + "cutlass": _cutlass_bmm_fp8_requirement, + }, + common_check=_check_bmm_fp8_problem_size, + heuristic_func=_heuristic_func_bmm_fp8, +) def bmm_fp8( A: torch.Tensor, B: torch.Tensor, @@ -2077,7 +2156,6 @@ def bmm_fp8( >>> out.dtype torch.bfloat16 """ - _validate_fp8_output_dtype(dtype) if out is None: out = torch.empty( @@ -2090,18 +2168,16 @@ def bmm_fp8( "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device ) - if backend == "cudnn": - backends = ["cudnn"] - elif backend == "cublas": - backends = ["cublas"] + if backend == "auto": + backends = bmm_fp8.suitable_auto_backends elif backend == "cutlass": - if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: - raise ValueError("e5m2 is not supported for cutlass backend") - backends = ["cutlass"] - elif backend == "auto": - backends = ["cutlass", "cublas", "cudnn"] + backends = _heuristic_func_bmm_fp8( + ["cutlass"], A, B, A_scale, B_scale, dtype, out, backend + ) + elif backend == "cudnn" and CUDNN_AVAILABLE: + backends = ["cudnn"] else: - raise ValueError(f"Unsupported backend: {backend}") + backends = [backend] fp8_gemm_sm100(A, B, A_scale, B_scale, out, workspace_buffer, backends) return out diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 3aae147896..771d616380 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -854,7 +854,9 @@ def is_cc_supported(cc): def backend_requirement( - backend_checks: Dict[str, Callable], common_check: Optional[Callable] = None + backend_checks: Dict[str, Callable], + common_check: Optional[Callable] = None, + heuristic_func: Optional[Callable] = None, ) -> Callable: """ Decorator to enforce backend and problem size requirements for kernel functions. 
@@ -1018,6 +1020,47 @@ def has_backend(backend: str) -> bool: # Whether the given backend exists in the API return backend in backend_checks + def suitable_auto_backends(cc, *args, **kwargs): + if common_check is not None and not common_check(*args, **kwargs): + return False + suitable_backends = [] + # Check for each backend support + for backend in backend_checks: + req_checker = backend_checks[backend] + try: + if req_checker( + *args, **kwargs + ) and req_checker.is_compute_capability_supported(cc): + suitable_backends.append(backend) + except ValueError: + continue + # If a heuristic function is provided, filter the suitable backends based on the heuristic function + if heuristic_func is not None: + suitable_backends = heuristic_func(suitable_backends, *args, **kwargs) + if not suitable_backends: + return False + wrapper.suitable_auto_backends = suitable_backends + return True + + def _get_capability(*args, **kwargs): + capability = None + # Find the first tensor argument. + # Assume all tensors are on the same device/capability. + # We could consider check all tensors at a performance cost. + tensor_arg = None + all_args = args + tuple(kwargs.values()) + for value in all_args: + if isinstance(value, torch.Tensor): + tensor_arg = value + break + + if tensor_arg is not None: + # Get compute capability from the first tensor + # Assume all tensors are on the same device/capability + major, minor = get_compute_capability(tensor_arg.device) + capability = major * 10 + minor + return capability + # @brief: Wrapper function that calls the orignal, decorated function, after applying a number of checks. # @note that here we manually apply defaults to the arguments in the wrapper function when doing validation. @functools.wraps(func) @@ -1034,45 +1077,48 @@ def wrapper(*args, **kwargs): bound_args.apply_defaults() # Convert to kwargs for validation functions kwargs_with_defaults = dict(bound_args.arguments) - backend = kwargs_with_defaults.get("backend") - - capability = None - # Find the first tensor argument. - # Assume all tensors are on the same device/capability. - # We could consider check all tensors at a performance cost. 
- tensor_arg = None - for value in kwargs_with_defaults.values(): - if isinstance(value, torch.Tensor): - tensor_arg = value - break - - if tensor_arg is not None: - # Get compute capability from the first tensor - # Assume all tensors are on the same device/capability - major, minor = get_compute_capability(tensor_arg.device) - capability = major * 10 + minor - + capability = _get_capability(*args, **kwargs) if not has_backend_choices() and common_check is None: raise ValueError( f"Invalid @backend_requirement decorator usage: no backend choices and no common_check for {func.__name__}" ) if has_backend_choices(): - if not is_backend_supported(backend, capability): - extra = f" with capability {capability}" if capability else "" - raise BackendSupportedError( - f"{func.__name__} does not support backend '{backend}'{extra}" - ) + if backend == "auto": + if not suitable_auto_backends( + capability, **kwargs_with_defaults + ): + raise BackendSupportedError( + f"No suitable auto backends found for {func.__name__}" + ) + else: + if not is_backend_supported(backend, capability): + extra = ( + f" with capability {capability}" if capability else "" + ) + raise BackendSupportedError( + f"{func.__name__} does not support backend '{backend}'{extra}" + ) + if not _is_problem_size_supported(**kwargs_with_defaults): + raise ValueError( + f"Problem size is not supported for {func.__name__}" + ) else: + # If the function doesnt have backends (i.e., there is only 1, implicit backend), run the following checks. if not is_compute_capability_supported(capability): raise BackendSupportedError( f"{func.__name__} does not support compute capability {capability}" ) - if not _is_problem_size_supported(**kwargs_with_defaults): - raise ValueError( - f"Problem size is not supported for {func.__name__}" - ) + if not _is_problem_size_supported(**kwargs_with_defaults): + raise ValueError( + f"Problem size is not supported for {func.__name__}" + ) + elif skip_check and heuristic_func is not None: + if kwargs.get("backend") == "auto": + # This needs to be called for heuristic function + capability = _get_capability(*args, **kwargs) + suitable_auto_backends(capability, *args, **kwargs) return func(*args, **kwargs) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 4f052019df..ebbda781fb 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -327,6 +327,97 @@ def my_kernel(x, backend="cudnn"): my_kernel(x_3d, backend="cudnn") +def test_suitable_auto_backends(): + """Test the suitable_auto_backends method.""" + if not torch.cuda.is_available(): + pytest.skip("Skipping CUDA tests (no GPU available)") + + x = torch.randn(1, 1, device="cuda") + major, minor = torch.cuda.get_device_capability(x.device) + actual_capability = major * 10 + minor + + @supported_compute_capability([80, 86, 89, 90, actual_capability]) + def _cutlass_check(x, backend): + return x.shape[0] > 10 + + @supported_compute_capability([75, 80, 86, 89, 90, actual_capability]) + def _cudnn_check(x, backend): + return x.shape[0] > 5 + + @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check}) + def my_kernel(x, backend="auto"): + backends = my_kernel.suitable_auto_backends + if x.shape[0] > 5: + assert "cudnn" in backends + if x.shape[0] > 10: + assert "cutlass" in backends + return x * 2 + + x = torch.randn(6, 10, device="cuda") + result = my_kernel(x, backend="auto") + assert result.shape == x.shape + + with pytest.raises( + BackendSupportedError, match="No suitable auto backends found 
for my_kernel" + ): + x = torch.randn(1, 1, device="cuda") + my_kernel(x, backend="auto") + + +def test_heuristic_func(): + """Test the heuristic_func parameter.""" + if not torch.cuda.is_available(): + pytest.skip("Skipping CUDA tests (no GPU available)") + + x = torch.randn(1, 1, device="cuda") + major, minor = torch.cuda.get_device_capability(x.device) + actual_capability = major * 10 + minor + + @supported_compute_capability([80, 86, 89, 90, actual_capability]) + def _cutlass_check(x, backend): + return x.shape[0] > 10 + + @supported_compute_capability([75, 80, 86, 89, 90, actual_capability]) + def _cudnn_check(x, backend): + return x.shape[0] > 5 + + @supported_compute_capability([75, 80, 86, 89, 90, actual_capability]) + def _trtllm_check(x, backend): + return x.shape[0] > 0 + + def _heuristic_func(suitable_backends, x, backend): + # Cutlass fails check + assert "cutlass" not in suitable_backends + + # Example: out of the supported backends in suitable_backends, + # cudnn is preferred over trtllm when shape[0] > 5 + if x.shape[0] > 5: + return ["cudnn", "trtllm"] + else: + return ["trtllm", "cudnn"] + + @backend_requirement( + {"cutlass": _cutlass_check, "cudnn": _cudnn_check, "trtllm": _trtllm_check}, + heuristic_func=_heuristic_func, + ) + def my_kernel(x, backend="auto"): + if x.shape[0] > 5: + assert my_kernel.suitable_auto_backends[0] == "cudnn" + assert my_kernel.suitable_auto_backends[1] == "trtllm" + else: + assert my_kernel.suitable_auto_backends[0] == "trtllm" + assert my_kernel.suitable_auto_backends[1] == "cudnn" + return x * 2 + + x = torch.randn(8, 10, device="cuda") + result = my_kernel(x, backend="auto") + assert result.shape == x.shape + + x = torch.randn(2, 10, device="cuda") + result = my_kernel(x, backend="auto") + assert result.shape == x.shape + + def test_functools_wraps_preserves_metadata(): """Test that backend_requirement preserves function metadata with functools.wraps.""" From 20435b409f778773fe10bc048fe5e28bd9babb12 Mon Sep 17 00:00:00 2001 From: nv-yunzheq Date: Fri, 7 Nov 2025 09:35:07 -0800 Subject: [PATCH 037/130] update trtllm cutlass moe (#2020) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * SM90 scatter-based epilogue and broader SM100/SM120 MOE/GEMM coverage; new public enum for GEMM stages and explicit runner instantiations. * **Improvements** * New runtime controls and parameters exposed: dynamic CGA, swap-AB, swizzled-input SF, unpadded hidden-size, and per-GEMM-stage tactic counts; expanded tile/cluster shape options, finalize-epilogue fusion and fusion/swap-aware dispatch; increased runtime debug logging and profiling. * **Bug Fixes** * License/namespace/header cleanups, suppressed compiler warnings, tightened assertions. 
* **Tests** * MXFP8ร—MXFP4 test now permits SM120 devices. --------- Co-authored-by: Yong Wu Co-authored-by: Alex Yang --- .../cutlass_fused_moe_instantiation.cu | 12 +- .../cutlass_fused_moe_kernels.cuh | 871 +++++++------ ...shinfer_cutlass_fused_moe_sm100_binding.cu | 101 +- .../include/tensorrt_llm/common/cudaUtils.h | 3 + .../epilogue/fusion/sm90_visitor_scatter.hpp | 757 ++++++++++++ .../include/cutlass_extensions/gemm_configs.h | 358 +++--- .../cutlass_extensions/util/gather_tensor.hpp | 19 +- .../cutlass_kernels/cutlass_heuristic.cpp | 121 +- .../cutlass_kernels/cutlass_heuristic.h | 10 +- .../fpA_intB_gemm/fpA_intB_gemm_template.h | 49 +- .../fpA_intB_gemm_template_sm90.h | 23 +- .../launchers/fpA_intB_launcher_sm90.h | 4 +- .../launchers/fpA_intB_launcher_sm90.inl | 23 +- .../include/moe_gemm_kernels.h | 130 +- .../cutlass_kernels/include/moe_kernels.h | 152 ++- .../include/moe_util_kernels.h | 12 +- .../launchers/fused_moe_gemm_launcher_sm80.h | 22 +- .../fused_moe_gemm_launcher_sm80.inl | 32 +- .../launchers/moe_gemm_tma_ws_launcher.h | 17 +- .../launchers/moe_gemm_tma_ws_launcher.inl | 1081 +++++++++-------- .../moe_gemm_tma_ws_mixed_input_launcher.h | 33 +- .../moe_gemm_tma_ws_mixed_input_launcher.inl | 76 +- .../moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 22 +- .../moe_gemm/moe_gemm_kernels_bf16_fp4.cu | 2 +- .../moe_gemm/moe_gemm_kernels_bf16_fp8.cu | 22 +- .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 22 +- .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp16_fp4.cu | 2 +- .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp4_fp4.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp8_fp4.cu | 2 +- .../moe_gemm/moe_gemm_kernels_fp8_fp8.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp8_uint4.cu | 22 +- .../moe_gemm/moe_gemm_template_dispatch.h | 267 ++-- .../moe_gemm_template_dispatch_tma_ws.h | 239 ++-- ...emm_template_dispatch_tma_ws_mixed_dtype.h | 14 +- .../moe_gemm_tma_warp_specialized_input.cu | 123 +- .../moe_tma_warp_specialized_traits.h | 38 +- flashinfer/fused_moe/core.py | 21 +- .../jit/gemm/cutlass/generate_kernels.py | 240 +++- tests/moe/test_trtllm_cutlass_fused_moe.py | 4 +- 44 files changed, 3272 insertions(+), 1828 deletions(-) create mode 100644 csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu index f20729f163..6469b9a0cd 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu @@ -18,7 +18,6 @@ #include "moe_kernels.h" namespace tensorrt_llm::kernels::cutlass_kernels { -// ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; #ifdef ENABLE_BF16 @@ -38,6 +37,7 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, 
__nv_fp8_e4m3>; #endif #endif #ifdef ENABLE_FP4 @@ -54,4 +54,12 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, _ template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>; #endif #endif -}; // namespace tensorrt_llm::kernels::cutlass_kernels + +// Explicit instantiations for finalizeMoeRoutingKernelLauncher to ensure +// symbols are emitted in the JIT library for common data types. +INSTANTIATE_FINALIZE_MOE_ROUTING(half, half, half); +INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float); +#ifdef ENABLE_BF16 +INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16); +#endif +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 85a77d7283..465241546d 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -284,6 +284,7 @@ void buildMinLatencyActiveExpertMaps( num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node, smart_routing, cluster_rank, cluster_size, num_experts_smem); } + template __global__ void fusedBuildExpertMapsSortFirstTokenKernel( int const* const token_selected_experts, int* const permuted_row_to_unpermuted_row, @@ -1007,7 +1008,8 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id, int64_t elem_idx, int64_t num_cols, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) { + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, + bool const swizzled_input_sf = true) { static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; // We need to offset into the scaling factors for just this expert @@ -1027,12 +1029,25 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, QuantizationSFLayout::SWIZZLED_128x4); if (sf_out) { if (input_sf) { - auto const sf_in = cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols / VecSize, const_cast(input_sf), - QuantizationSFLayout::SWIZZLED_128x4); - *sf_out = *sf_in; + if (swizzled_input_sf) { + auto const sf_in = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, + const_cast(input_sf), + QuantizationSFLayout::SWIZZLED_128x4); + *sf_out = *sf_in; + } else { + auto const sf_in = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, + const_cast(input_sf), + QuantizationSFLayout::LINEAR); + *sf_out = *sf_in; + } } else { *sf_out = 0x00; } @@ -1075,18 +1090,25 @@ __device__ void setupFP4BlockScalingFactors( TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* weight_block_scale, int64_t num_tokens_before_expert) { - assert(layout_info.fpX_block_scaling_factors_stride_A); - assert(layout_info.fpX_block_scaling_factors_stride_B); - - // M & N swapped for 
transpose - auto stride_a_ptr = reinterpret_cast( - layout_info.fpX_block_scaling_factors_stride_A); - auto stride_b_ptr = reinterpret_cast( - layout_info.fpX_block_scaling_factors_stride_B); - stride_a_ptr[expert] = BSConfig::tile_atom_to_shape_SFB( - cute::make_shape((int)gemm_n, (int)gemm_m, (int)gemm_k, (int)1)); - stride_b_ptr[expert] = BSConfig::tile_atom_to_shape_SFA( - cute::make_shape((int)gemm_n, (int)gemm_m, (int)gemm_k, (int)1)); + assert(layout_info.fpX_block_scaling_factors_stride_act); + assert(layout_info.fpX_block_scaling_factors_stride_weight); + + auto stride_act_ptr = reinterpret_cast( + layout_info.fpX_block_scaling_factors_stride_act); + auto stride_weight_ptr = reinterpret_cast( + layout_info.fpX_block_scaling_factors_stride_weight); + if (layout_info.swap_ab) { + // M & N swapped for transpose + stride_act_ptr[expert] = BSConfig::tile_atom_to_shape_SFB( + cute::make_shape((int)gemm_n, (int)gemm_m, (int)gemm_k, (int)1)); + stride_weight_ptr[expert] = BSConfig::tile_atom_to_shape_SFA( + cute::make_shape((int)gemm_n, (int)gemm_m, (int)gemm_k, (int)1)); + } else { + stride_act_ptr[expert] = BSConfig::tile_atom_to_shape_SFA( + cute::make_shape((int)gemm_m, (int)gemm_n, (int)gemm_k, (int)1)); + stride_weight_ptr[expert] = BSConfig::tile_atom_to_shape_SFB( + cute::make_shape((int)gemm_m, (int)gemm_n, (int)gemm_k, (int)1)); + } // This assert validates our current assumption that A&B can be safely transposed without needing // to modify @@ -1099,30 +1121,51 @@ __device__ void setupFP4BlockScalingFactors( std::is_same_v ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - layout_info.fpX_block_scaling_factors_A[expert] = + layout_info.fpX_block_scaling_factors_act[expert] = fp4_act_flat + getOffsetActivationSF(expert, num_tokens_before_expert, gemm_k, scaling_type); - layout_info.fpX_block_scaling_factors_B[expert] = + layout_info.fpX_block_scaling_factors_weight[expert] = weight_block_scale + getOffsetWeightSF(expert, gemm_n, gemm_k, scaling_type); } __device__ void computeTmaWarpSpecializedInputStrides( TmaWarpSpecializedGroupedGemmInput& layout_info, int gemm_m, int gemm_n, int gemm_k, int64_t out_idx) { - layout_info.stride_a[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::StrideA{}, cute::make_shape(gemm_m, gemm_k, 1)); - layout_info.stride_b[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::StrideB{}, cute::make_shape(gemm_n, gemm_k, 1)); + if (layout_info.swap_ab) { + reinterpret_cast( + layout_info.stride_act)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideB{}, + cute::make_shape(gemm_m, gemm_k, 1)); + reinterpret_cast( + layout_info.stride_weight)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideA{}, + cute::make_shape(gemm_n, gemm_k, 1)); + } else { + reinterpret_cast( + layout_info.stride_act)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideA{}, + cute::make_shape(gemm_m, gemm_k, 1)); + reinterpret_cast( + layout_info.stride_weight)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideB{}, + cute::make_shape(gemm_n, gemm_k, 1)); + } if (layout_info.stride_c) { + // TODO Enable 1xN bias matrix as C assert(false && "CUTLASS does not support a 1xN bias"); - // layout_info.stride_c[out_idx] = cute::make_stride(0, cute::Int<1>{}, 0); - layout_info.stride_c[out_idx] = 
cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::StrideC{}, cute::make_shape(1, gemm_n, 1)); } if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info.default_epilogue.stride_d[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD{}, - cute::make_shape(gemm_n, gemm_m, 1)); + if (layout_info.swap_ab) { + reinterpret_cast( + layout_info.stride_d)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideD_T{}, + cute::make_shape(gemm_n, gemm_m, 1)); + } else { + reinterpret_cast( + layout_info.stride_d)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideD{}, + cute::make_shape(gemm_m, gemm_n, 1)); + } } if (layout_info.int4_groupwise_params.enabled) { layout_info.int4_groupwise_params.stride_s_a[out_idx] = cutlass::make_cute_packed_stride( @@ -1142,18 +1185,27 @@ __device__ void computeTmaWarpSpecializedInputPointers( TmaWarpSpecializedGroupedGemmInput& layout_info, int64_t gemm_m, int64_t gemm_n, int64_t gemm_k, int num_tokens_before_expert, int64_t expert, T const* in, WeightType const* weights, TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale, - ScaleBiasType const* bias, OutputType* output, int64_t const out_idx) { + ScaleBiasType const* bias, OutputType* output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, int64_t const out_idx) { // The input prior to this contains K elements per token, with `num_tokens_before_expert` tokens - layout_info.ptr_a[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k); + layout_info.ptr_act[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k); // Each expert's weight matrix is a constant size NxK, get the matrix at index `expert` - layout_info.ptr_b[out_idx] = safe_inc_ptr(weights, expert * (gemm_n * gemm_k)); + layout_info.ptr_weight[out_idx] = safe_inc_ptr(weights, expert * (gemm_n * gemm_k)); if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { // The output prior to this contains N elements per token, with `num_tokens_before_expert` // tokens - layout_info.default_epilogue.ptr_d[out_idx] = - safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + layout_info.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + } + if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) { + layout_info.fused_finalize_epilogue.ptr_source_token_index[expert] = + permuted_row_to_unpermuted_row + num_tokens_before_expert; + layout_info.fused_finalize_epilogue.ptr_router_scales[expert] = + router_scales + num_tokens_before_expert; + if (layout_info.fused_finalize_epilogue.ptr_bias != nullptr) { + layout_info.fused_finalize_epilogue.ptr_bias[expert] = bias + gemm_n * expert; + } } if (layout_info.int4_groupwise_params.enabled) { // The group size of wfp4a16 is multiplied by 2 because each scale uses 1 byte instead of 2 @@ -1180,7 +1232,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel( TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, - OutputType* gemm2_output) { + OutputType* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row) { // First, compute the global tid. 
We only need 1 thread per expert. int const expert = blockIdx.x * blockDim.x + threadIdx.x; if (expert >= num_experts_per_node) { @@ -1199,22 +1252,26 @@ __global__ void computeStridesTmaWarpSpecializedKernel( // M and N transposed since we are using the #tokens as the N dimension layout_info1.shape_info.problem_shapes[expert] = - TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm1_n, gemm_m, - gemm1_k); + TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape( + layout_info1.swap_ab ? gemm1_n : gemm_m, layout_info1.swap_ab ? gemm_m : gemm1_n, + gemm1_k); layout_info2.shape_info.problem_shapes[expert] = - TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm2_n, gemm_m, - gemm2_k); + TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape( + layout_info2.swap_ab ? gemm2_n : gemm_m, layout_info2.swap_ab ? gemm_m : gemm2_n, + gemm2_k); if (layout_info1.int4_groupwise_params.enabled) { layout_info1.int4_groupwise_params.shape.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::ProblemShapeInt:: - UnderlyingProblemShape(gemm1_n, gemm_m, gemm1_k); + UnderlyingProblemShape(layout_info1.swap_ab ? gemm1_n : gemm_m, + layout_info1.swap_ab ? gemm_m : gemm1_n, gemm1_k); } if (layout_info2.int4_groupwise_params.enabled) { layout_info2.int4_groupwise_params.shape.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::ProblemShapeInt:: - UnderlyingProblemShape(gemm2_n, gemm_m, gemm2_k); + UnderlyingProblemShape(layout_info2.swap_ab ? gemm2_n : gemm_m, + layout_info2.swap_ab ? gemm_m : gemm2_n, gemm2_k); } if (alpha_scale_flat1 && alpha_scale_flat2) { @@ -1241,9 +1298,6 @@ __global__ void computeStridesTmaWarpSpecializedKernel( setupIfSelected(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaledConfig{}, quant_params.mxfp8_mxfp4); -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif assert(gemm_m <= INT32_MAX); assert(gemm1_n > 0 && gemm1_n <= INT32_MAX); assert(gemm1_k > 0 && gemm1_k <= INT32_MAX); @@ -1256,142 +1310,15 @@ __global__ void computeStridesTmaWarpSpecializedKernel( layout_info1, gemm_m, gemm1_n, gemm1_k, num_tokens_before_expert, expert, gemm1_in, weights1, reinterpret_cast( quant_params.groupwise.fc1.weight_scales), - bias1, gemm1_output, expert); + bias1, gemm1_output, nullptr, nullptr, expert); computeTmaWarpSpecializedInputPointers( layout_info2, gemm_m, gemm2_n, gemm2_k, num_tokens_before_expert, expert, gemm2_in, weights2, reinterpret_cast( quant_params.groupwise.fc2.weight_scales), - bias2, gemm2_output, expert); -} - -template -__global__ void computeStridesTmaWarpSpecializedLowLatencyKernel( - TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, - int64_t gemm1_k, int64_t gemm2_n, int64_t gemm2_k, int64_t const num_experts_per_node, - T const* in1, T const* in2, WeightType const* weights1, WeightType const* weights2, - float const* alpha_scale_flat1, float const* alpha_scale_flat2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* output1, - OutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, - int start_expert) { - // First, compute the global tid. 
We only need 1 thread per expert. - int const expert = blockIdx.x * blockDim.x + threadIdx.x; - - if (expert >= num_experts_per_node) { - return; - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - - // Note: expert is used to calculate the offset of the input and output - // local_expert is used to calculate the offset of the weight - auto const num_tokens_before_expert = expert * num_tokens; - bool const is_active_expert = expert < *num_active_experts_per; - int const local_expert = is_active_expert ? active_expert_global_ids[expert] - start_expert : -1; - auto const gemm_m = is_active_expert ? num_tokens : 0; - - // M and N transposed since we are using the #tokens as the N dimension - layout_info1.shape_info.problem_shapes[expert] = - TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm1_n, gemm_m, - gemm1_k); - layout_info2.shape_info.problem_shapes[expert] = - TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm2_n, gemm_m, - gemm2_k); - - if (alpha_scale_flat1) { - assert(alpha_scale_flat2); - if (is_active_expert) { - layout_info1.alpha_scale_ptr_array[expert] = alpha_scale_flat1 + local_expert; - layout_info2.alpha_scale_ptr_array[expert] = alpha_scale_flat2 + local_expert; - } else { - layout_info1.alpha_scale_ptr_array[expert] = nullptr; - layout_info2.alpha_scale_ptr_array[expert] = nullptr; - } - } - - if (quant_params.fp4.fc1.weight_block_scale) { - setupFP4BlockScalingFactors( - layout_info1, expert, gemm_m, gemm1_n, gemm1_k, fp4_act_flat1, - quant_params.fp4.fc1.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc1 uses the same A input for all experts and the scaling - // factor B offsets from the local expert index - if (is_active_expert) { - layout_info1.fpX_block_scaling_factors_A[expert] = fp4_act_flat1; - layout_info1.fpX_block_scaling_factors_B[expert] = - quant_params.fp4.fc1.weight_block_scale + - getOffsetWeightSF(local_expert, gemm1_n, gemm1_k, - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info1.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info1.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - - if (quant_params.fp4.fc2.weight_block_scale) { - setupFP4BlockScalingFactors( - layout_info2, expert, gemm_m, gemm2_n, gemm2_k, fp4_act_flat2, - quant_params.fp4.fc2.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc2 scaling factor B offsets by the local expert index - if (is_active_expert) { - layout_info2.fpX_block_scaling_factors_B[expert] = - quant_params.fp4.fc2.weight_block_scale + - getOffsetWeightSF(local_expert, gemm2_n, gemm2_k, - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info2.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info2.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - + bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif - - assert(gemm_m <= INT32_MAX); - assert(gemm1_n > 0 && gemm1_n <= INT32_MAX); - assert(gemm1_k > 0 && gemm1_k <= INT32_MAX); - assert(gemm2_n > 0 && gemm2_n <= INT32_MAX); - assert(gemm2_k > 0 && gemm2_k <= INT32_MAX); - computeTmaWarpSpecializedInputStrides(layout_info1, gemm_m, gemm1_n, gemm1_k, expert); - computeTmaWarpSpecializedInputStrides(layout_info2, gemm_m, gemm2_n, gemm2_k, expert); - - if 
(is_active_expert) { - // Note: under low latency mode, we use the same input for all experts - // so for gemm1, the inputs are the same, - // for gemm2, we use the input generated by gemm1 - layout_info1.ptr_a[expert] = in1; - layout_info2.ptr_a[expert] = safe_inc_ptr(in2, expert * num_tokens * gemm2_k); - - // Each expert's weight matrix is a constant size NxK, get the matrix at index `expert` - layout_info1.ptr_b[expert] = safe_inc_ptr(weights1, local_expert * (gemm1_n * gemm2_k)); - layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k)); - - assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); - layout_info1.default_epilogue.ptr_d[expert] = - safe_inc_ptr(output1, expert * num_tokens * gemm1_n); - - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - // The output prior to this contains N elements per token, with `num_tokens` tokens - layout_info2.default_epilogue.ptr_d[expert] = - safe_inc_ptr(output2, expert * num_tokens * gemm2_n); - } - } else { - layout_info1.ptr_a[expert] = nullptr; - layout_info2.ptr_a[expert] = nullptr; - layout_info1.ptr_b[expert] = nullptr; - layout_info2.ptr_b[expert] = nullptr; - - layout_info1.default_epilogue.ptr_d[expert] = nullptr; - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info2.default_epilogue.ptr_d[expert] = nullptr; - } - } } // ========================== Permutation things ======================================= @@ -1426,7 +1353,7 @@ __global__ void expandInputRowsKernel( int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) { static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ, @@ -1536,7 +1463,7 @@ __global__ void expandInputRowsKernel( "Cannot use per-expert act scale for pre-quantized activations"); writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, - fc1_act_sf_flat, input_sf); + fc1_act_sf_flat, input_sf, swizzled_input_sf); dest_row_ptr[elem_index] = in_vec; } } @@ -1632,8 +1559,8 @@ void expandInputRowsKernelLauncher( int const k, int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, - bool enable_pdl, cudaStream_t stream) { + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + void const* prequant_scales, bool enable_pdl, cudaStream_t stream) { #ifdef ENABLE_FP4 TLLM_CHECK_WITH_INFO( (std::is_same_v && fc1_act_sf_flat) || @@ -1672,7 +1599,7 @@ void expandInputRowsKernelLauncher( // Could be either regular FP8 or MXFP8 else if constexpr (std::is_same_v && std::is_same_v) { - TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ"); + TLLM_CHECK_WITH_INFO(!prequant_scales, "FP8 is not supported for AWQ"); return quant_params.mxfp8_mxfp4.fc1.weight_block_scale ? 
&expandInputRowsKernel< InputActivationsType, ExpandedActivationsType, @@ -1714,21 +1641,22 @@ void expandInputRowsKernelLauncher( cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales, permuted_row_to_unpermuted_row, num_rows, hidden_size, k, quant_params.fp4.fc1.act_global_scale, use_per_expert_act_scale, - expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node, + expert_first_token_offset, fc1_act_sf_flat, input_sf, swizzled_input_sf, + num_experts_per_node, reinterpret_cast(prequant_scales)); } -#define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \ - template void expandInputRowsKernelLauncher( \ - InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, \ - float const* unpermuted_scales, float* permuted_scales, \ - int const* permuted_row_to_unpermuted_row, int64_t const num_rows, \ - int64_t const hidden_size, int const k, int const num_experts_per_node, \ - QuantParams const& quant_params, bool use_per_expert_act_scale, \ - int64_t* expert_first_token_offset, \ - TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \ - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, \ - bool enable_pdl, cudaStream_t stream) +#define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \ + template void expandInputRowsKernelLauncher( \ + InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, \ + float const* unpermuted_scales, float* permuted_scales, \ + int const* permuted_row_to_unpermuted_row, int64_t const num_rows, \ + int64_t const hidden_size, int const k, int const num_experts_per_node, \ + QuantParams const& quant_params, bool use_per_expert_act_scale, \ + int64_t* expert_first_token_offset, \ + TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \ + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, \ + void const* prequant_scales, bool enable_pdl, cudaStream_t stream) // Instantiate the data types that are used by the external pytorch op // INSTANTIATE_EXPAND_INPUT_ROWS(float, float); @@ -1751,22 +1679,24 @@ template ::value, sizeof_bits::value); - assert(orig_cols % FINALIZE_ELEM_PER_THREAD == 0); - int64_t const start_offset = threadIdx.x; int64_t const stride = FINALIZE_THREADS_PER_BLOCK; - int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; + int64_t const num_elems_in_padded_col = padded_cols / FINALIZE_ELEM_PER_THREAD; + int64_t const num_elems_in_orig_col = unpadded_cols / FINALIZE_ELEM_PER_THREAD; using BiasElem = cutlass::Array; using InputElem = cutlass::Array; @@ -1781,7 +1711,7 @@ __global__ void finalizeMoeRoutingKernel( #endif #pragma unroll - for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { + for (int elem_index = start_offset; elem_index < num_elems_in_orig_col; elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { @@ -1794,20 +1724,15 @@ __global__ void finalizeMoeRoutingKernel( int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; - int64_t expanded_rows = num_rows * experts_per_token; - if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { - continue; - } - float const row_scale = (SCALE_MODE == 
ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; auto const* expanded_permuted_rows_row_ptr = - expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; + expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_padded_col; ComputeElem expert_result = arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); if (bias) { - auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; + auto const* bias_ptr = bias_v + expert_id * num_elems_in_padded_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); } @@ -1830,8 +1755,13 @@ __global__ void finalizeMoeRoutingNoFillingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* token_selected_experts, - int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const orig_cols, - int64_t const experts_per_token, int const num_experts_per_node, int const start_expert_id) { + int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const padded_cols, + int64_t const unpadded_cols, int64_t const experts_per_token, int const num_experts_per_node, + int const start_expert_id) { + assert(padded_cols % 4 == 0); + assert(unpadded_cols % 4 == 0); + assert(unpadded_cols <= padded_cols); + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif @@ -1860,17 +1790,16 @@ __global__ void finalizeMoeRoutingNoFillingKernel( continue; } - OutputType* reduced_row_ptr = reduced_unpermuted_output + source_row * orig_cols; + OutputType* reduced_row_ptr = reduced_unpermuted_output + source_row * unpadded_cols; // Load 128-bits per thread, according to the smallest data type we read/write constexpr int64_t FINALIZE_ELEM_PER_THREAD = 128 / std::min(sizeof_bits::value, sizeof_bits::value); - assert(orig_cols % FINALIZE_ELEM_PER_THREAD == 0); - int64_t const start_offset = threadIdx.x; int64_t const stride = FINALIZE_THREADS_PER_BLOCK; - int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; + int64_t const num_elems_in_padded_col = padded_cols / FINALIZE_ELEM_PER_THREAD; + int64_t const num_elems_in_orig_col = unpadded_cols / FINALIZE_ELEM_PER_THREAD; using BiasElem = cutlass::Array; using InputElem = cutlass::Array; @@ -1881,7 +1810,10 @@ __global__ void finalizeMoeRoutingNoFillingKernel( reinterpret_cast(expanded_permuted_rows); auto* reduced_row_ptr_v = reinterpret_cast(reduced_row_ptr); - for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { + for (int elem_index = start_offset; elem_index < num_elems_in_padded_col; + elem_index += stride) { + if (elem_index >= num_elems_in_orig_col) continue; // Skip writing beyond original columns + ComputeElem thread_output; thread_output.fill(0); for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { @@ -1893,22 +1825,17 @@ __global__ void finalizeMoeRoutingNoFillingKernel( int64_t const expanded_permuted_row_from_k_idx = unpermuted_row_to_permuted_row[source_row + k_idx * num_rows]; - int64_t valid_tokens = expert_first_token_offset[num_experts_per_node]; - if (expanded_permuted_row_from_k_idx < 0 || - expanded_permuted_row_from_k_idx >= valid_tokens) { - continue; - } float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 
1.f : scales[k_offset]; auto const* expanded_permuted_rows_row_ptr = - expanded_permuted_rows_v + expanded_permuted_row_from_k_idx * num_elems_in_col; + expanded_permuted_rows_v + expanded_permuted_row_from_k_idx * num_elems_in_padded_col; ComputeElem expert_result = arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); if (bias) { - auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; + auto const* bias_ptr = bias_v + expert_id * num_elems_in_padded_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); } @@ -1928,10 +1855,10 @@ void finalizeMoeRoutingKernelLauncher( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales, int const* unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* token_selected_experts, - int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, - int64_t const experts_per_token, int64_t const num_experts_per_node, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, - cudaStream_t stream) { + int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const padded_cols, + int64_t const unpadded_cols, int64_t const experts_per_token, + int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool enable_pdl, cudaStream_t stream) { // Only add bias on rank 0 for tensor parallelism bool const is_rank_0 = parallelism_config.tp_rank == 0; ScaleBiasType const* bias_ptr = is_rank_0 ? bias : nullptr; @@ -1962,8 +1889,8 @@ void finalizeMoeRoutingKernelLauncher( ScaleMode::NO_SCALE>; cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, - token_selected_experts, expert_first_token_offset, num_rows, cols, - experts_per_token, num_experts_per_node, start_expert_id); + token_selected_experts, expert_first_token_offset, num_rows, padded_cols, + unpadded_cols, experts_per_token, num_experts_per_node, start_expert_id); } else { // If all-gather reduce-scatter is used, finalizeMoeRouting must fill invalid output tokens with // zeros. 
@@ -1976,20 +1903,21 @@ void finalizeMoeRoutingKernelLauncher( : &finalizeMoeRoutingKernel; cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, - final_scales, unpermuted_row_to_permuted_row, token_selected_experts, cols, - experts_per_token, num_experts_per_node, start_expert_id); + final_scales, unpermuted_row_to_permuted_row, token_selected_experts, + padded_cols, unpadded_cols, experts_per_token, num_experts_per_node, + start_expert_id); } } -#define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \ - template void finalizeMoeRoutingKernelLauncher( \ - GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, \ - ScaleBiasT const* bias, float const* final_scales, \ - int const* expanded_source_row_to_expanded_dest_row, \ - int const* expanded_dest_row_to_expanded_source_row, int const* expert_for_source_row, \ - int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, \ - int64_t const experts_per_token, int64_t const num_experts_per_node, \ - MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, \ +#define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \ + template void finalizeMoeRoutingKernelLauncher( \ + GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, \ + ScaleBiasT const* bias, float const* final_scales, \ + int const* unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, \ + int const* expert_for_source_row, int64_t const* expert_first_token_offset, \ + int64_t const num_rows, int64_t const padded_cols, int64_t const actual_cols, \ + int64_t const experts_per_token, int64_t const num_experts_per_node, \ + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, \ cudaStream_t stream); // // Instantiate the data types that are used by the external pytorch op @@ -2172,7 +2100,6 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x) { size_t gemm_result_offset = token * inter_size * gated_size_mul; size_t output_offset = token * inter_size; @@ -2188,6 +2115,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; + gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f; gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f; gate_limit = activation_params.swiglu_limit ? 
activation_params.swiglu_limit[expert] @@ -2245,7 +2173,6 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, linear_value + arrayConvert(bias_ptr_vec[elem_index]); } return fn(fc1_value, linear_value); - } else { return fn(fc1_value); } @@ -2382,16 +2309,17 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 &doActivationKernel, decltype(block_scaling_type)::value> // Identity - }; return fn_list[static_cast(activation_type.activation_type)]; }; +#ifdef ENABLE_FP4 auto NVFP4 = tensorrt_llm::common::ConstExprWrapper< TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{}; auto MXFPX = tensorrt_llm::common::ConstExprWrapper< TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX>{}; +#endif auto NONE = tensorrt_llm::common::ConstExprWrapper< TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE>{}; @@ -2834,10 +2762,11 @@ void CutlassMoeFCRunnerepilogue_fusion_type == + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + permuted_token_final_scales_ = + gemm2_using_finalize_fusion ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr; bool const is_gated_activation = isGatedActivation(activation_type); bool const gemm1_using_fused_moe = moe_gemm_runner_.isFusedGatedActivation( @@ -2979,9 +2908,10 @@ void CutlassMoeFCRunner( static_cast(gemm_output), final_output, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, - token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, - num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream); + token_selected_experts, expert_first_token_offset, num_rows, hidden_size, + unpadded_hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall, + enable_pdl, stream); } template (use_ampere_activation_fusion ? output : intermediate_result), alpha_scale_ptr_array, /*occupancy*/ nullptr, - use_ampere_activation_fusion ? fc1_activation_type : ActivationType::Identity, + use_ampere_activation_fusion ? 
fc1_activation_type.activation_type + : ActivationType::Identity, expanded_num_rows, /*N*/ int64_t(fc1_out_size), /*K*/ hidden_size, @@ -3271,9 +3202,9 @@ void CutlassMoeFCRunner( static_cast(gemm_output), final_output, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, - token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, - num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream); + token_selected_experts, expert_first_token_offset, num_rows, hidden_size, + unpadded_hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall, + enable_pdl, stream); } else if (!using_tma_ws_gemm2) { finalizeMoeRoutingKernelLauncher( static_cast(gemm_output), final_output, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, - token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, - num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream); + token_selected_experts, expert_first_token_offset, num_rows, hidden_size, + unpadded_hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall, + enable_pdl, stream); } sync_check_cuda_error(stream); } @@ -3600,16 +3533,16 @@ void CutlassMoeFCRunner void CutlassMoeFCRunner::runMoe( - void const* input_activations_void, void const* input_sf_void, + void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf, int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void, ActivationParams fc1_activation_type, void const* fc2_expert_weights_void, void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, - int64_t const hidden_size, int64_t const inter_size, int const full_num_experts, - int const experts_per_token, char* workspace_ptr, void* final_output_void, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, bool use_lora, LoraParams& lora_params, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, + int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, + int const full_num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output_void, int* unpermuted_row_to_permuted_row, + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, + LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) { static constexpr bool int_scales_required = std::is_same::value || std::is_same::value || @@ -3664,6 +3597,27 @@ void CutlassMoeFCRunner::value)); } else { + // For NoSmem epilogue schedule, we need to align the output of the GEMM to 256 bits, for gated + // activation this is automatic if the usual alignment requirement is met + if (gemm1_config_->epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM && + !isGatedActivation(fc1_activation_type)) { + TLLM_CHECK_WITH_INFO( + inter_size % (256 / sizeof_bits::value) == 0, + "Inter size %d does not meet minimum alignment requirements for MOE GEMM %d", + (int)inter_size, (int)(256 / sizeof_bits::value)); + } + + if (gemm2_config_->epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM) { + TLLM_CHECK_WITH_INFO( + gemm2_config_->epilogue_fusion_type != + 
cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE, + "Got NoSmem epilogue schedule, which is not supported for finalize fusion"); + TLLM_CHECK_WITH_INFO( + hidden_size % (256 / sizeof_bits::value) == 0, + "Hidden size %d does not meet minimum alignment requirements for MOE GEMM %d", + (int)hidden_size, (int)(256 / sizeof_bits::value)); + } + // Require at least 128 bits of alignment for MOE GEMM TLLM_CHECK_WITH_INFO( hidden_size % (128 / sizeof_bits::value) == 0, @@ -3755,10 +3709,11 @@ void CutlassMoeFCRunner:: TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, - UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, bool enable_pdl, + UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, + float const* router_scales, int const* permuted_row_to_unpermuted_row, bool enable_pdl, cudaStream_t stream) { // Always nullptr layout_info1.ptr_c = nullptr; @@ -3923,6 +3882,11 @@ CutlassMoeFCRunner:: layout_info2.ptr_c = nullptr; layout_info2.stride_c = nullptr; + layout_info1.fused_finalize_epilogue.ptr_bias = nullptr; + if (!bias2) { + layout_info2.fused_finalize_epilogue.ptr_bias = nullptr; + } + auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale : use_fp8 ? fp8_dequant1 @@ -3964,7 +3928,8 @@ CutlassMoeFCRunner:: layout_info2, num_tokens, expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, gemm1_in, gemm2_in, weights1, weights2, alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, - quant_params, bias1, bias2, gemm1_output, gemm2_output); + quant_params, bias1, bias2, gemm1_output, gemm2_output, router_scales, + permuted_row_to_unpermuted_row); return std::make_pair(layout_info1, layout_info2); } @@ -3985,55 +3950,7 @@ CutlassMoeFCRunner:: UnfusedGemmOutputType* output1, UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert, bool enable_pdl, cudaStream_t stream) { - TLLM_CHECK_WITH_INFO(!use_w4_groupwise, - "W4AFP8 and WFP4A16 are not supported in low latency mode"); - - // Always nullptr - layout_info1.ptr_c = nullptr; - layout_info1.stride_c = nullptr; - layout_info2.ptr_c = nullptr; - layout_info2.stride_c = nullptr; - - auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale - : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale - : fp8_dequant1; - auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale - : use_wfp4afp8 ? 
quant_params.fp8_mxfp4.fc2.global_scale - : fp8_dequant2; - if (!alpha_scale_flat1) { - layout_info1.alpha_scale_ptr_array = nullptr; - } - if (!alpha_scale_flat2) { - layout_info2.alpha_scale_ptr_array = nullptr; - } - - layout_info1.int4_groupwise_params.enabled = false; - layout_info2.int4_groupwise_params.enabled = false; - layout_info1.int4_groupwise_params.use_wfp4a16 = false; - layout_info2.int4_groupwise_params.use_wfp4a16 = false; - - int const threads = std::min(1024, num_experts); - int const blocks = (num_experts + threads - 1) / threads; - - cudaLaunchConfig_t config; - config.gridDim = blocks; - config.blockDim = threads; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx( - &config, - computeStridesTmaWarpSpecializedLowLatencyKernel, - layout_info1, layout_info2, num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts, - input1, input2, weights1, weights2, alpha_scale_flat1, alpha_scale_flat2, fc1_fp4_act_flat, - fc2_fp4_act_flat, quant_params, bias1, bias2, output1, output2, num_active_experts_per, - active_expert_global_ids, start_expert); - - return std::make_pair(layout_info1, layout_info2); + TLLM_THROW("Min latency mode is no longer supported"); } template :: setupTmaWarpSpecializedInputs(int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, - int64_t inter_size, int64_t num_experts_per_node, - void const* input_activations_void, + int64_t unpadded_hidden_size, int64_t inter_size, + int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -4081,6 +3998,8 @@ CutlassMoeFCRunner:: gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + gemm1_tma_ws_input.swap_ab = true; + gemm2_tma_ws_input.swap_ab = true; TLLM_CHECK_WITH_INFO(gemm1_input != gemm1_output, "Input and output buffers are overlapping"); return Self::computeStridesTmaWarpSpecializedLowLatency( @@ -4098,17 +4017,28 @@ CutlassMoeFCRunner:: gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + gemm1_tma_ws_input.swap_ab = gemm1_config_->swap_ab; + gemm2_tma_ws_input.swap_ab = gemm2_config_->swap_ab; + TLLM_CHECK_WITH_INFO( + (gemm1_tma_ws_input.swap_ab && gemm2_tma_ws_input.swap_ab) || !use_w4_groupwise, + "Hopper w4 mixed input groupwise requires swap_ab"); + bool apply_bias = parallelism_config.tp_rank == 0; - bool using_hopper_fused_finalize = !use_deterministic_hopper_reduce_ && - gemm2_config_->sm_version == 90 && !use_w4_groupwise && - !use_lora; - if (using_hopper_fused_finalize) { + auto* fc2_bias = apply_bias ? 
fc2_expert_biases : nullptr; + bool gemm2_using_finalize_fusion = + gemm2_config_->epilogue_fusion_type == + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + bool using_fused_finalize = + use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora; + TLLM_CHECK_WITH_INFO( + using_fused_finalize == gemm2_using_finalize_fusion, + "GEMM2 tactic requests finalize fusion, but the runner is not configured to use it"); + if (using_fused_finalize) { assert(min_latency_mode == false); + bool use_reduction = expanded_num_rows > num_rows; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - gemm2_tma_ws_input.setFinalizeFusionParams( - final_output, permuted_token_final_scales_, expert_first_token_offset_, - permuted_row_to_unpermuted_row_, apply_bias ? fc2_expert_biases : nullptr, hidden_size, - num_rows); + gemm2_tma_ws_input.setFinalizeFusionParams(final_output, unpadded_hidden_size, num_rows, + use_reduction); } // fp8_mxfp4 memsets the scaling factors to 1.0f @@ -4120,14 +4050,10 @@ CutlassMoeFCRunner:: "WFP4AFP8 expects the scaling factors to be aliased for gemm1 & gemm2"); TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF weight_block_scale_value_int{}; -#ifdef ENABLE_FP8 -#if CUDA_VERSION >= 12080 +#if defined(FLASHINFER_ENABLE_FP8_E8M0) && CUDART_VERSION >= 12080 __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(1.0f, __NV_SATFINITE, cudaRoundPosInf); std::memcpy(&weight_block_scale_value_int, &tmp, sizeof(tmp)); -#else - TLLM_CHECK_WITH_INFO(false, "WFP4AFP8 is not supported on CUDA "); -#endif #endif auto act_sf_rows = std::min(expanded_num_rows, num_rows * num_experts_per_node); @@ -4150,9 +4076,9 @@ CutlassMoeFCRunner:: reinterpret_cast(gemm1_input), reinterpret_cast(gemm2_input), fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2, fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, - fc1_expert_biases, fc2_expert_biases, - reinterpret_cast(gemm1_output), - reinterpret_cast(fc2_result_), enable_pdl, stream); + fc1_expert_biases, fc2_bias, reinterpret_cast(gemm1_output), + reinterpret_cast(fc2_result_), permuted_token_final_scales_, + permuted_row_to_unpermuted_row_, enable_pdl, stream); } } @@ -4412,7 +4338,7 @@ std::map> GemmProfilerBackend::getProfile if (is_tma_ws_input) { tma_ws_input_workspace_size = TmaWarpSpecializedGroupedGemmInput::workspaceSize(num_experts_per_node, mScalingType) * - (NUM_ROUTING_SAMPLES + 1); + (NUM_ROUTING_SAMPLES * NUM_FUSION_TYPES * NUM_SWAP_AB_TYPES + 1); if (is_w4afp8_quant || is_wfp4a16_quant) { quant_3_size = 0; @@ -4509,7 +4435,6 @@ std::map> GemmProfilerBackend::getProfile ADD(swiglu_alpha); ADD(swiglu_beta); ADD(swiglu_limit); - #undef ADD_NAME #undef ADD @@ -4640,13 +4565,32 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr } } -void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr_char, - void const* expert_weights, bool enable_pdl, - cudaStream_t stream) { +void GemmProfilerBackend::prepareTmaWsInputs( + int num_tokens, char* workspace_ptr_char, void const* expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, bool enable_pdl, + cudaStream_t stream) { if (mSM < 90) { return; } + bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); + bool use_wfp4a16 = + ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) && + mWType == nvinfer1::DataType::kUINT8); 
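+  // Identify the W4 group-wise weight-only paths (W4AFP8 and WFP4A16). Finalize-fusion
+  // TMA workspace inputs are only prepared when profiling GEMM2 with fused finalize
+  // enabled and outside min-latency mode; W4 group-wise weights additionally require
+  // swap_ab, so the other variants are skipped by the early returns below.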
+ bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; + bool const use_finalize_fusion = + fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; + bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || + use_w4_groupwise || + mGemmToProfile != GemmToProfile::GEMM_2; + if (use_finalize_fusion && finalize_fusion_not_supported) { + return; + } + + if (use_w4_groupwise && !swap_ab) { + return; + } + auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90); #define GET_WS_PTR(type, name) \ @@ -4684,11 +4628,19 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr dummy_tma_ws_input.enable_pdl = enable_pdl; // Set enable_pdl for dummy input tma_ws_input_workspace += tma_ws_size; + int workspace_index = + static_cast(use_finalize_fusion) * (NUM_SWAP_AB_TYPES * NUM_ROUTING_SAMPLES) + + static_cast(swap_ab) * NUM_ROUTING_SAMPLES; + tma_ws_input_workspace += workspace_index * tma_ws_size; + size_t num_expanded_tokens = num_tokens * mK; for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++) { - mTmaInputCache[i].configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, - workspaces.at("gemm_workspace").first, mScalingType); - mTmaInputCache[i].enable_pdl = enable_pdl; // Set enable_pdl for the profiler + // Note: Even though we have separate TMA WS inputs for finalize fusion on/off we reuse the same + // pointers to save space. + auto& cache_element = mTmaInputCache[use_finalize_fusion][swap_ab][i]; + cache_element.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, + workspaces.at("gemm_workspace").first, mScalingType); + cache_element.enable_pdl = enable_pdl; // Set enable_pdl for cache element tma_ws_input_workspace += tma_ws_size; int64_t* expert_first_token_offset = @@ -4697,34 +4649,27 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr permuted_row_to_unpermuted_row_base + i * num_expanded_tokens; auto& gemm1_tma_ws_input = - mGemmToProfile == GemmToProfile::GEMM_1 ? mTmaInputCache[i] : dummy_tma_ws_input; + mGemmToProfile == GemmToProfile::GEMM_1 ? cache_element : dummy_tma_ws_input; auto& gemm2_tma_ws_input = - mGemmToProfile == GemmToProfile::GEMM_2 ? mTmaInputCache[i] : dummy_tma_ws_input; + mGemmToProfile == GemmToProfile::GEMM_2 ? cache_element : dummy_tma_ws_input; if (mSM >= 90) { + auto fc1_output_size = + isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize; + /* GEMM1 */ gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; - bool apply_bias = true; - bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); - bool use_wfp4a16 = - ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) && - mWType == nvinfer1::DataType::kUINT8); - bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; + gemm1_tma_ws_input.swap_ab = swap_ab; + gemm2_tma_ws_input.swap_ab = swap_ab; - bool using_fused_finalize = !mInterface->use_deterministic_hopper_reduce_ && mSM == 90 && - !mMinLatencyMode && !use_w4_groupwise; - if (using_fused_finalize) { + if (use_finalize_fusion) { assert(!mMinLatencyMode); gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - gemm2_tma_ws_input.setFinalizeFusionParams( - output, token_topk_unpermuted_scales, expert_first_token_offset, - permuted_row_to_unpermuted_row, apply_bias ? 
bias : nullptr, mExpertHiddenSize, - num_tokens); + gemm2_tma_ws_input.setFinalizeFusionParams(output, mExpertUnpaddedHiddenSize, num_tokens, + mK > 1); } - auto fc1_output_size = - isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize; if (mMinLatencyMode) { std::tie(gemm1_tma_ws_input, gemm2_tma_ws_input) = mInterface->computeStridesTmaWarpSpecializedLowLatencyDispatch( @@ -4742,7 +4687,7 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr mExpertInterSize, mNumExpertsPerNode, input, input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate, - enable_pdl, stream); + token_topk_unpermuted_scales, permuted_row_to_unpermuted_row, enable_pdl, stream); } sync_check_cuda_error(stream); } @@ -4752,7 +4697,6 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, void const* expert_weights, bool enable_pdl, cudaStream_t stream) { - mAllTacticsSaved = mInterface->getTactics(); mSampleIndex = 0; auto workspace_size = getWorkspaceSize(num_tokens); @@ -4760,7 +4704,13 @@ void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, prepareRouting(num_tokens, workspace_ptr_char, enable_pdl, stream); prepareQuantParams(num_tokens, workspace_ptr_char, stream); - prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, enable_pdl, stream); + for (auto fusion : {TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE}) { + for (auto swap_ab : {false, true}) { + prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, fusion, swap_ab, + enable_pdl, stream); + } + } } size_t GemmProfilerBackend::getWorkspaceSize(int maxM) { @@ -4827,54 +4777,95 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac TmaWarpSpecializedGroupedGemmInput tma_ws_input_template; if (tactic.is_tma_warp_specialized) { - tma_ws_input_template = mTmaInputCache[mSampleIndex]; + // Use non-finalize cache when finalize fusion is not supported for the current GEMM + bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); + bool use_wfp4a16 = + ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) && + mWType == nvinfer1::DataType::kUINT8); + bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; + bool finalize_supported_this_gemm = (mGemmToProfile == GemmToProfile::GEMM_2) && + mInterface->use_fused_finalize_ && !mMinLatencyMode && + !use_w4_groupwise; + bool request_finalize = tactic.epilogue_fusion_type == + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + bool use_finalize_index = request_finalize && finalize_supported_this_gemm; + + tma_ws_input_template = mTmaInputCache[use_finalize_index][tactic.swap_ab][mSampleIndex]; + TLLM_CHECK_WITH_INFO(tma_ws_input_template.isValid(), + "TMA WS input template is not initialized"); } mInterface->is_profiler = true; if (mGemmToProfile == GemmToProfile::GEMM_1) { - mInterface->gemm1( - input, // - output, // - intermediate, // - expert_first_token_offset, // - tma_ws_input_template, // - weights_sel, // - bias, // - expert_first_token_offset + num_experts_per_node, // - mQuantParams.wo.fc1_weight_scales, // - mQuantParams.fp8.dequant_fc1, // - mQuantParams.fp8_mxfp4.fc2.act_global_scale ? 
mQuantParams.fp8_mxfp4.fc2.act_global_scale - : mQuantParams.fp8.quant_fc2, // - fp4_act_scale_flat, // - fp4_act_scale_flat, // - mQuantParams, // - original_num_tokens, // - expanded_num_tokens, // - mExpertHiddenSize, // - mExpertInterSize, // - num_experts_per_node, // - ActivationParams(mActivationType, swiglu_alpha, swiglu_beta, swiglu_limit), // - alpha_scale_ptr_array, // - !mUseLora, // - /*use_deepseek_fp8_block_scale=*/false, // - stream, // - tactic, // - mMinLatencyMode, // - num_active_experts_per_node, // - active_expert_global_ids, // - enable_pdl); // + mInterface->gemm1(input, // + output, // + intermediate, // + expert_first_token_offset, // + tma_ws_input_template, // + weights_sel, // + bias, // + expert_first_token_offset + num_experts_per_node, // + mQuantParams.wo.fc1_weight_scales, // + mQuantParams.fp8.dequant_fc1, // + mQuantParams.fp8_mxfp4.fc2.act_global_scale + ? mQuantParams.fp8_mxfp4.fc2.act_global_scale + : mQuantParams.fp8.quant_fc2, // + fp4_act_scale_flat, // + fp4_act_scale_flat, // + mQuantParams, // + original_num_tokens, // + expanded_num_tokens, // + mExpertHiddenSize, // + mExpertInterSize, // + num_experts_per_node, // + ActivationParams(mActivationType, swiglu_alpha, swiglu_beta, swiglu_limit), + alpha_scale_ptr_array, // + !mUseLora, // + /*use_deepseek_fp8_block_scale=*/false, // + stream, // + tactic, // + mMinLatencyMode, // + num_active_experts_per_node, // + active_expert_global_ids, // + enable_pdl); // } else { TLLM_CHECK(mGemmToProfile == GemmToProfile::GEMM_2); - mInterface->gemm2( - input, intermediate, output, expert_first_token_offset, tma_ws_input_template, weights_sel, - bias, mQuantParams.wo.fc2_weight_scales, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, - mQuantParams, token_topk_unpermuted_scales, token_topk_permuted_scales, - unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, - expert_first_token_offset + mNumExpertsPerNode, original_num_tokens, expanded_num_tokens, - mExpertHiddenSize, mExpertInterSize, num_experts_per_node, mK, alpha_scale_ptr_array, false, - nullptr, - /*use_deepseek_fp8_block_scale=*/false, stream, mParallelismConfig, mEnableAlltoall, tactic, - mMinLatencyMode, num_active_experts_per_node, active_expert_global_ids, enable_pdl); + mInterface->gemm2(input, // + intermediate, // + output, // + expert_first_token_offset, // + tma_ws_input_template, // + weights_sel, // + bias, // + mQuantParams.wo.fc2_weight_scales, // + mQuantParams.fp8.dequant_fc2, // + fp4_act_scale_flat, // + mQuantParams, // + token_topk_unpermuted_scales, // + token_topk_permuted_scales, // + unpermuted_row_to_permuted_row, // + permuted_row_to_unpermuted_row, // + token_selected_experts, // + expert_first_token_offset + mNumExpertsPerNode, // + original_num_tokens, // + expanded_num_tokens, // + mExpertHiddenSize, // + mExpertUnpaddedHiddenSize, // + mExpertInterSize, // + num_experts_per_node, // + mK, // + alpha_scale_ptr_array, // + false, // + nullptr, // + /*use_deepseek_fp8_block_scale=*/false, // + stream, // + mParallelismConfig, // + mEnableAlltoall, // + tactic, // + mMinLatencyMode, // + num_active_experts_per_node, // + active_expert_global_ids, // + enable_pdl); // } mInterface->is_profiler = false; diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu index 267e4591cc..8d996da98e 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu 
+++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu @@ -72,6 +72,8 @@ class DtypeUtils { default: TVM_FFI_ICHECK(false) << "unsupported data type"; } + + return nvinfer1::DataType::kFLOAT; // supress compiler warning } private: @@ -111,6 +113,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { TVM_FFI_ICHECK(false) << "Invalid output type " << DLDataTypeToString(output_type) << " specified for " << DLDataTypeToString(mActivationDtype); } + + return nullptr; // supress compiler warning }; FusedMoeRunner(DLDataType activation_dtype, DLDataType weight_dtype, DLDataType output_dtype, @@ -219,7 +223,13 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { } mProfiler = std::make_shared(); - mAllProfiles = mKernelRunner->getTactics(); + // Get tactics for both GEMM1 and GEMM2, combine them + auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1); + auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2); + mGemm1TacticCount = static_cast(gemm1_tactics.size()); + mGemm2TacticCount = static_cast(gemm2_tactics.size()); + mAllProfiles = gemm1_tactics; + mAllProfiles.insert(mAllProfiles.end(), gemm2_tactics.begin(), gemm2_tactics.end()); TVM_FFI_ICHECK(!mAllProfiles.empty()) << "No valid tactics available for fused moe op with the requested input combination " "Activation: " @@ -367,10 +377,14 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; + // HACK Define default values for parameters we don't have good values for + bool const swizzled_input_sf = true; // Assume input_sf is swizzled by default + int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default + bool const use_lora = false; // No lora support yet #ifdef USING_OSS_CUTLASS_MOE_GEMM mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, - reinterpret_cast(token_selected_experts.data_ptr()), + swizzled_input_sf, reinterpret_cast(token_selected_experts.data_ptr()), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value().data_ptr()) : nullptr, @@ -378,16 +392,16 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr, activation_params, fc2_expert_weights.data_ptr(), fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr, - quant_params, num_rows, hidden_size, inter_size, num_experts_total, + quant_params, num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, - false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, + use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); #else mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, - reinterpret_cast(token_selected_experts.data_ptr()), + swizzled_input_sf, reinterpret_cast(token_selected_experts.data_ptr()), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value().data_ptr()) : nullptr, @@ -396,10 +410,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { activation_params, fc2_expert_weights.data_ptr(), fc2_expert_biases.has_value() ? 
fc2_expert_biases.value().data_ptr() : nullptr, quant_params, num_rows, hidden_size, inter_size, num_experts_total, - static_cast(experts_per_token), static_cast(workspace_info.workspace), - output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, - false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - enable_pdl, stream); + static_cast(experts_per_token), + static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), + static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, + mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); #endif } @@ -547,10 +561,14 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; + // HACK Define default values for parameters we don't have good values for + bool const swizzled_input_sf_ml = true; // Assume input_sf is swizzled by default + int64_t const unpadded_hidden_size_ml = hidden_size; // Assume no padding by default + bool const use_lora_ml = false; // No lora support yet #ifdef USING_OSS_CUTLASS_MOE_GEMM mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, - reinterpret_cast(token_selected_experts.data_ptr()), + swizzled_input_sf_ml, reinterpret_cast(token_selected_experts.data_ptr()), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value().data_ptr()) : nullptr, @@ -558,16 +576,16 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr, activation_params, fc2_expert_weights.data_ptr(), fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr, - quant_params, num_rows, hidden_size, inter_size, num_experts_total, + quant_params, num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total, static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, - false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, + use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); #else mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, - reinterpret_cast(token_selected_experts.data_ptr()), + swizzled_input_sf_ml, reinterpret_cast(token_selected_experts.data_ptr()), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value().data_ptr()) : nullptr, @@ -575,11 +593,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr, activation_params, fc2_expert_weights.data_ptr(), fc2_expert_biases.has_value() ? 
fc2_expert_biases.value().data_ptr() : nullptr, - quant_params, num_rows, hidden_size, inter_size, num_experts_total, - static_cast(experts_per_token), static_cast(workspace_info.workspace), - output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, - false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - enable_pdl, stream); + quant_params, num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total, + static_cast(experts_per_token), + static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), + static_cast(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml, + lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, + stream); #endif } @@ -641,19 +660,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { auto activation_dtype = (mUseW4GroupScaling && !isWFP4A16Quant()) ? dl_float8_e4m3fn : mActivationDtype; activation_dtype = isNvfp4Quant() ? dl_int64 : activation_dtype; + int64_t const unpadded_hidden_size_profiler = hidden_size; // HACK no padding by default #ifdef USING_OSS_CUTLASS_MOE_GEMM mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile, DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype), DtypeUtils::dataType(mOutputDtype), num_experts, static_cast(top_k), - hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA, - min_latency_mode, + hidden_size, unpadded_hidden_size_profiler, inter_size, group_size, + activation_type, USE_BIAS, USE_LORA, min_latency_mode, /*need_weights*/ false, parallelism_config, enable_alltoall); #else mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile, DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype), DtypeUtils::dataType(mOutputDtype), num_experts, static_cast(top_k), - hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA, - min_latency_mode, + hidden_size, unpadded_hidden_size_profiler, inter_size, group_size, + activation_type, USE_BIAS, USE_LORA, min_latency_mode, /*need_weights*/ false, parallelism_config); #endif @@ -691,6 +711,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { }); } else if (name == "get_tactic_num") { return Function::FromTyped([this]() -> int64_t { return getTacticNum(); }); + } else if (name == "get_gemm1_tactic_count") { + return Function::FromTyped([this]() -> int64_t { return mGemm1TacticCount; }); + } else if (name == "get_gemm2_tactic_count") { + return Function::FromTyped([this]() -> int64_t { return mGemm2TacticCount; }); } else if (name == "run_moe") { return Function::FromTyped( [this](TensorView output, TensorView input, TensorView token_selected_experts, @@ -758,6 +782,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; std::vector mAllProfiles; + int64_t mGemm1TacticCount{0}; + int64_t mGemm2TacticCount{0}; void setRunnerProfiles(Optional> profile_ids) { if (mUseDeepSeekFP8BlockScaling) { @@ -771,13 +797,34 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { } auto best_gemm1_profile = mAllProfiles.front(); - auto best_gemm2_profile = mAllProfiles.front(); + // Default GEMM2 profile should come from the GEMM2 subrange if present + auto best_gemm2_profile = + (mGemm2TacticCount > 0 && mAllProfiles.size() > static_cast(mGemm1TacticCount)) + ? 
mAllProfiles.at(mGemm1TacticCount) + : mAllProfiles.front(); if (profile_ids.has_value()) { TVM_FFI_ICHECK_EQ(profile_ids.value().size(), 2) << "Expecting 2 profile ids"; - best_gemm1_profile = profile_ids.value()[0] == -1 ? best_gemm1_profile - : mAllProfiles.at(profile_ids.value()[0]); - best_gemm2_profile = profile_ids.value()[1] == -1 ? best_gemm2_profile - : mAllProfiles.at(profile_ids.value()[1]); + // GEMM1 index: accept absolute index; otherwise if clearly out of combined range, keep + // default + auto id1 = profile_ids.value()[0]; + if (id1 != -1) { + TVM_FFI_ICHECK(id1 >= 0 && id1 < mGemm1TacticCount) << "Invalid gemm1 profile id: " << id1; + best_gemm1_profile = mAllProfiles.at(id1); + } + + // GEMM2 index: support both absolute (combined) and relative (within GEMM2 subrange) ids + auto id2 = profile_ids.value()[1]; + if (id2 != -1) { + int64_t absolute_id2 = id2; + // If id2 appears relative to GEMM2 subrange, offset it + if (id2 >= 0 && id2 < mGemm2TacticCount) { + absolute_id2 = mGemm1TacticCount + id2; + } + TVM_FFI_ICHECK(absolute_id2 >= 0 && + absolute_id2 < static_cast(mAllProfiles.size())) + << "Invalid gemm2 profile id: " << id2; + best_gemm2_profile = mAllProfiles.at(absolute_id2); + } } mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile); } diff --git a/csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h b/csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h index ccddbc1ef5..5f757f1b51 100644 --- a/csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h +++ b/csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h @@ -1181,6 +1181,9 @@ using Int = ConstExprWrapper; template using Bool = ConstExprWrapper; +template +using ConstBool = ConstExprWrapper; + template struct TmaDescType; diff --git a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp new file mode 100644 index 0000000000..c98f7ee3c1 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp @@ -0,0 +1,757 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass_extensions/arch/copy_red_global.hpp" +#include "cutlass_extensions/util/gather_tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +template < + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm90ScatterPtrArray { + + using SmemShape = decltype(make_shape(size(make_layout(get<0>(EpilogueTile{}))), size(make_layout(get<1>(EpilogueTile{}))))); + using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{})); + + using ElementIndex = int32_t; + + static constexpr bool MajorMode = cutlass::gemm::detail::is_major<0,StrideOutput>() ? 
0 : 1; + + using StrideIndex = decltype(replace<1-MajorMode>(Stride<_0,_0,_0>{}, Int<1>{})); + + struct SharedStorage {}; + + struct Arguments { + ElementOutput* ptr_out{}; // output tensor pointer + StrideOutput dOut = {}; // output tensor stride + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + int shape_override = -1; // override value for contiguous output tensor mode + bool use_reduction = true; // use reduction or regular store + }; + + struct Params { + ElementOutput* ptr_out{}; // output tensor pointer + StrideOutput dOut = {}; // output tensor stride + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + cutlass::FastDivmod index_divmod{}; // modulo used to transform the index before store + int shape_override = -1; // override value for contiguous output tensor mode + bool use_reduction = true; // use reduction or regular store + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return { + args.ptr_out, + args.dOut, + args.ptr_index, + cutlass::FastDivmod(args.index_modulo), + args.shape_override, + args.use_reduction + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class ArgsTuple + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple) + : args_tuple(std::move(args_tuple)) {} + + ArgsTuple args_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rOut_frg = recast>(coalesce(tC_rOut)); // (EPI_V) + tC_rOut_frg(epi_v) = convert_input(frg_input); + + return tC_rOut_frg(epi_v); + } + + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + Tensor byte_buffer = recast(reduction_buffer); + static_assert(cosize(byte_buffer.layout()) * sizeof_bits_v >= cosize(SmemLayout{}) * sizeof_bits_v, + "Not enough space in scratch smem buffer"); + + Tensor sOut = 
as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(recast_ptr(byte_buffer.data())), SmemLayout{})); + + auto thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_sOut_epi = thread_r2s.partition_D(sOut); + Tensor tRS_rOut_epi = thread_r2s.retile_S(tC_rOut); + + auto thread_r2g = tiled_r2g_red.get_slice(thread_idx); + Tensor tRG_gOut_epi = tRG_gOut(_,_,_,epi_m,epi_n); + Tensor tRG_sOut_epi = thread_r2g.partition_D(sOut); + Tensor tRG_rOut_epi = thread_r2g.retile_S(make_tensor(tC_rOut.data(), shape(tRG_sOut_epi))); // reuse D registers + + // sanity check for register reuse + CUTE_STATIC_ASSERT_V(cosize(tC_rOut.layout()) == cosize(tRG_rOut_epi.layout()), "Invalid register count for R2G"); + + copy(tiled_r2s, tRS_rOut_epi, tRS_sOut_epi); + sync_fn(); + copy(tRG_sOut_epi, tRG_rOut_epi); + + auto residue = residue_cD; // capturing structured bindings is a C++20 feature + Tensor tRG_cD_epi = tRG_cD(0,_,_,epi_m,epi_n); + auto pred = cute::lazy::transform(tRG_cD_epi, [&](auto c){ return elem_less(c, residue); }); + + if (use_reduction) { + copy_if(tiled_r2g_red, pred, tRG_rOut_epi, tRG_gOut_epi); + } + else { + copy_if(tiled_r2g_stg, pred, tRG_rOut_epi, tRG_gOut_epi); + } + } + }; + + template + static constexpr auto get_reduction_op() + { + using namespace cute; + + // For now only support red.add + if constexpr (is_same_v) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM70_RED_ADD_NOFTZ_F16x2{}; + } + else { + return SM70_RED_ADD_NOFTZ_F16{}; + } + } + else if constexpr (is_same_v) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2{}; + } + else { + return SM90_RED_ADD_NOFTZ_BF16{}; + } + } + else { + // non-vectorized atomic add for all other types until supported + return TypedAtomicAdd{}; + } + } + + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto index_read = [index = params_ptr->ptr_index[l], divmod = params_ptr->index_divmod](auto i){ return divmod.rem(index[i]); }; + Tensor mOut = cutlass::util::make_gather_tensor(params_ptr->ptr_out, make_shape(M,N,Int<1>{}), params_ptr->dOut, index_read); // (M,N,_1) + Tensor gOut = local_tile(mOut, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gOut_epi = flat_divide(gOut, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor mIdx = make_tensor(params_ptr->ptr_index[l], make_shape(M,N,Int<1>{}), StrideIndex{}); // (M,N,_1) + Tensor gIdx = local_tile(mIdx, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gIdx_epi = flat_divide(gIdx, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor cD_epi = flat_divide(args.cD, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor tC_gOut = sm90_partition_for_epilogue(gOut, args.epi_tile, args.tiled_copy, args.thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Tensor tC_rOut = make_tensor(take<0,3>(shape(tC_gOut))); // (CPY,CPY_M,CPY_N) + + auto tiled_r2s = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + + // Vectorization must not exceed alignment and also the number of values per thread in the tile + int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy)); + int constexpr NumValTile = product(take<0,2>(shape(cD_epi))); + int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads); + + // Choose the largest available red.global op and an st.global op with matching vectorization + using CopyOpR2GRed = decltype(get_reduction_op()); + using CopyOpR2GStg = UniversalCopy::NumValSrc * sizeof_bits_v>>; + + auto make_tiled_r2g = [&](auto copy_op) + { + using CopyAtomR2G = Copy_Atom; + constexpr int VecSize = CopyAtomR2G::NumValSrc; + if constexpr (cutlass::gemm::detail::is_k_major()) { + constexpr int ThreadsMajor = size<1>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout, Int>, Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) { + constexpr int ThreadsMajor = size<0>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout, Int>, Stride<_1, Int>>{}, + Layout, _1>>{}); + } + else { + static_assert(cute::is_void_v, "Unsupported D gmem layout."); + } + }; + + auto tiled_r2g_red = make_tiled_r2g(CopyOpR2GRed{}); + auto tiled_r2g_stg = make_tiled_r2g(CopyOpR2GStg{}); + + // Sanity checks - since we will be using one tiled copy with tensors partitioned with the other tiled copy, + // ensure they have matching layouts/tilers + using TiledR2GRed = decltype(tiled_r2g_red); + using TiledR2GStg = decltype(tiled_r2g_stg); + static_assert(typename TiledR2GRed::AtomLayoutSrc{} == typename TiledR2GStg::AtomLayoutSrc{}, "Mismatching AtomLayoutSrc"); + static_assert(typename TiledR2GRed::AtomLayoutDst{} == typename TiledR2GStg::AtomLayoutDst{}, "Mismatching AtomLayoutDst"); + static_assert(typename TiledR2GRed::TiledLayout_TV{} == typename TiledR2GStg::TiledLayout_TV{}, "Mismatching TiledLayout_TV"); + static_assert(typename TiledR2GRed::Tiler_MN{} == typename 
TiledR2GStg::Tiler_MN{}, "Mismatching Tiler_MN"); + + auto thread_r2g = tiled_r2g_red.get_slice(args.thread_idx); + Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + Tensor tRG_cD = thread_r2g.partition_D(cD_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + + auto residue_cD = args.residue_cD; + + // If shape_override is set, adjust residue_cD to change predication. + // This is used to support fused slicing (where the output tensor is smaller than problem shape) + if (params_ptr->shape_override >= 0) { + get(residue_cD) += params_ptr->shape_override - get(args.problem_shape_mnkl); + } + + auto args_tuple = make_tuple( + cute::move(tC_rOut), + tiled_r2s, + tRG_gOut, + tRG_cD, + tiled_r2g_red, + tiled_r2g_stg, + params_ptr->use_reduction, + args.thread_idx, + residue_cD); + + return ConsumerStoreCallbacks(std::move(args_tuple)); + } +}; + +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBias + : ScaledAcc +{ + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerColBias + : ScaledAcc +{ + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerColBiasSupported = true; +}; + +template< + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBiasPerColScaleScatter + : ScaledAccPerRowBias +{ + using ElementAux = ElementOutput; + using GmemLayoutTagAux = GmemLayoutTagOut; + static constexpr int AlignmentAux = AlignmentOutput; + static constexpr bool IsAuxOutSupported = true; +}; + +template< + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerColBiasPerRowScaleScatter + : ScaledAccPerColBias +{ + using ElementAux = ElementOutput; + using GmemLayoutTagAux = GmemLayoutTagOut; + static constexpr int AlignmentAux = AlignmentOutput; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = alpha * acc + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPtrArray = + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, 
ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + >; + +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerColBiasPtrArray = + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + >; + +template< + class CtaTileShapeMNK, + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray = + Sm90EVT, // scatter store + Sm90EVT, // scale * (alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_0,_1,int64_t>, 1>, // scale + Sm90ScaledAccPerRowBiasPtrArray // alpha * acc + bias + > + >; + +template< + class CtaTileShapeMNK, + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray = + Sm90EVT, // scatter store + Sm90EVT, // scale * (alpha * acc + bias) + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_1,_0,int64_t>, 1>, // scale + Sm90ScaledAccPerColBiasPtrArray // alpha * acc + bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScale, + class ElementScalar, + int AlignmentBias, + int AlignmentOutput, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::ScaledAccPerRowBiasPerColScaleScatter, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + > { + + using StrideOutput = cutlass::gemm::TagToStrideC_t; + + using Impl = Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + StrideOutput, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + >; + using Operation = fusion::ScaledAccPerRowBiasPerColScaleScatter< + GmemLayoutTagOut, + ElementOutput, + ElementCompute, + ElementBias, + ElementScale, + ElementScalar, + AlignmentBias, + AlignmentOutput, + 
RoundStyle>; + + struct Arguments { + + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* const* bias_ptr{}; + StrideBias dBias{}; + + using StrideScale = Stride<_0,_1,int64_t>; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + // Nested args not usable due to a compiler bug with constexpr evaluation + // using ScatterArguments = typename Sm90ScatterPtrArray::Arguments; + // ScatterArguments scatter{}; + + ElementOutput* ptr_out{}; // output tensor pointer + StrideOutput dOut{}; // output tensor stride + int const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + int shape_override = -1; // override value for contiguous output tensor mode + bool use_reduction = true; // use reduction or regular store + + operator typename Impl::Arguments() const { + return + { // unary op: reduce(scale * (beta * C + (alpha * acc))) + { // binary op: scale * (beta * C + (alpha * acc)) + { scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end binary op + {} // binary args: multiply + }, // end binary op + //scatter // unary args: reduce + { ptr_out, dOut, ptr_index, index_modulo, shape_override, use_reduction } + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; + +}; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScale, + class ElementScalar, + int AlignmentBias, + int AlignmentOutput, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::ScaledAccPerColBiasPerRowScaleScatter, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + > { + + using StrideOutput = cutlass::gemm::TagToStrideC_t; + + using Impl = Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + StrideOutput, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + >; + using Operation = fusion::ScaledAccPerColBiasPerRowScaleScatter< + GmemLayoutTagOut, + ElementOutput, + ElementCompute, + ElementBias, + ElementScale, + ElementScalar, + AlignmentBias, + AlignmentOutput, + RoundStyle>; + + struct Arguments { + + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* const* bias_ptr{}; + StrideBias 
dBias{}; + + using StrideScale = Stride<_1,_0,int64_t>; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + // Nested args not usable due to a compiler bug with constexpr evaluation + // using ScatterArguments = typename Sm90ScatterPtrArray::Arguments; + // ScatterArguments scatter{}; + + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + int const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + int shape_override = -1; // override value for contiguous output tensor mode + bool use_reduction = true; + + operator typename Impl::Arguments() const { + return + { // unary op: reduce(scale * (beta * C + (alpha * acc))) + { // binary op: scale * (beta * C + (alpha * acc)) + { scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end binary op + {} // binary args: multiply + }, // end binary op + //scatter // unary args: reduce + { ptr_out, dOut, ptr_index, index_modulo, shape_override, use_reduction } + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; + +}; + +} // namespace cutlass::epilogue::fusion + +// clang-format on diff --git a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index 6c5a823e8e..b2301c1a82 100644 --- a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,12 @@ #include #include #include +#include +#include #include "cute/tensor.hpp" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/tllmException.h" namespace tensorrt_llm { namespace cutlass_extensions { @@ -30,10 +34,10 @@ namespace cutlass_extensions { // in the kernel layout details when doing weight only quantization. 
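// Note: the TMA warp-specialized tile-config enums below (SM90/SM100/SM120) encode their MxNxK
// shape directly in the enum value via shape_tuple_to_enum (m * 1000000 + n * 1000 + k), e.g.
// CtaShape128x256x128B == 128256128, so get_tile_shape_name() can recover the shape from the
// enum value alone.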
enum class CutlassTileConfig { // Signals that we should run heuristics do choose a config - Undefined, + Undefined = 0, // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + ChooseWithHeuristic = 1, // SiMT config CtaShape128x128x8_WarpShape64x64x8, @@ -77,77 +81,96 @@ enum class SplitKStyle { // SPLIT_K_PARALLEL // Not supported yet }; -enum class CutlassTileConfigSM90 { +constexpr static int shape_tuple_to_enum(int m, int n, int k) { + assert(m >= 0 && n >= 0 && k >= 0); + assert(m < 1000 && n < 1000 && k < 1000); + return m * 1000000 + n * 1000 + k; +} + +template +constexpr static std::tuple enum_to_shape_tuple(TEnum shape_id_enum) { + static_assert(std::is_enum_v && std::is_same_v, int>, + "TEnum must be an enum with underlying type int"); + auto shape_id = static_cast(shape_id_enum); + assert(shape_id >= 0); + assert(shape_id < (int)1e9); + return std::make_tuple(shape_id / 1000000, (shape_id % 1000000) / 1000, shape_id % 1000); +} + +enum class CutlassTileConfigSM90 : int { // Signals that we should run heuristics do choose a config - Undefined, + Undefined = 0, // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + ChooseWithHeuristic = 1, // CTA configs for M=64 - CtaShape64x16x128B, - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, + CtaShape64x16x128B = shape_tuple_to_enum(64, 16, 128), + CtaShape64x32x128B = shape_tuple_to_enum(64, 32, 128), + CtaShape64x64x128B = shape_tuple_to_enum(64, 64, 128), + CtaShape64x128x128B = shape_tuple_to_enum(64, 128, 128), + CtaShape64x256x128B = shape_tuple_to_enum(64, 256, 128), // CTA configs for M=128 - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, + CtaShape128x16x128B = shape_tuple_to_enum(128, 16, 128), + CtaShape128x32x128B = shape_tuple_to_enum(128, 32, 128), + CtaShape128x64x128B = shape_tuple_to_enum(128, 64, 128), + CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128), + CtaShape128x256x128B = shape_tuple_to_enum(128, 256, 128), // CTA configs for M=256 - CtaShape256x128x128B, - CtaShape256x256x128B, + CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128), + CtaShape256x256x128B = shape_tuple_to_enum(256, 256, 128), }; -enum class CutlassTileConfigSM100 { +enum class CutlassTileConfigSM100 : int { // Signals that we should run heuristics do choose a config - Undefined, + Undefined = 0, // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + ChooseWithHeuristic = 1, /* * Grouped GEMM */ // M=64 - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, + CtaShape64x32x128B = shape_tuple_to_enum(64, 32, 128), + CtaShape64x64x128B = shape_tuple_to_enum(64, 64, 128), + CtaShape64x128x128B = shape_tuple_to_enum(64, 128, 128), + CtaShape64x256x128B = shape_tuple_to_enum(64, 256, 128), // M=128 - CtaShape128x8x256B, - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, - CtaShape128x128x256B, - CtaShape128x256x256B, + CtaShape128x8x256B = shape_tuple_to_enum(128, 8, 256), + CtaShape128x16x128B = shape_tuple_to_enum(128, 16, 128), + CtaShape128x32x128B = shape_tuple_to_enum(128, 32, 128), + CtaShape128x64x128B = shape_tuple_to_enum(128, 64, 128), + CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128), + CtaShape128x256x128B = shape_tuple_to_enum(128, 256, 128), + CtaShape128x128x256B = shape_tuple_to_enum(128, 128, 256), + 
CtaShape128x256x256B = shape_tuple_to_enum(128, 256, 256), // M=256 - CtaShape256x64x128B, - CtaShape256x128x128B, - CtaShape256x256x128B, + CtaShape256x64x128B = shape_tuple_to_enum(256, 64, 128), + CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128), + CtaShape256x256x128B = shape_tuple_to_enum(256, 256, 128), }; -enum class CutlassTileConfigSM120 { +// An alias to make the SHAPE_CASE macro work +using CutlassTileConfigSM103 = CutlassTileConfigSM100; + +enum class CutlassTileConfigSM120 : int { // Signals that we should run heuristics do choose a config - Undefined, + Undefined = 0, // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - CtaShape128x128x128B, - CtaShape128x128x64B, - CtaShape256x128x64B, - CtaShape128x256x64B, - CtaShape128x128x256B, - CtaShape256x128x128B, + ChooseWithHeuristic = 1, + + CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128), + CtaShape128x128x64B = shape_tuple_to_enum(128, 128, 64), + CtaShape256x128x64B = shape_tuple_to_enum(256, 128, 64), + CtaShape128x256x64B = shape_tuple_to_enum(128, 256, 64), + CtaShape128x128x256B = shape_tuple_to_enum(128, 128, 256), + CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128), }; enum class MainloopScheduleType { @@ -175,115 +198,73 @@ enum class EpilogueScheduleType { AUTO, // Automatically chooses an epilogue schedule compatible with the selected main loop // schedule for Hopper. For architectures older than hopper, the epilogue is always // performed by the same thread block as the main loop. + NO_SMEM, + TMA }; -enum class TileShape { - TileShape_64x16x128, - TileShape_64x32x128, - TileShape_64x64x128, - TileShape_64x128x128, - TileShape_64x256x128, - TileShape_64x512x128, - TileShape_128x16x128, - TileShape_128x32x128, - TileShape_128x64x128, - TileShape_128x128x128, - TileShape_128x256x128, - TileShape_256x128x128, - TileShape_256x256x128 +enum class TileShape : int { + Undefined = 0, + TileShape_64x16x128 = shape_tuple_to_enum(64, 16, 128), + TileShape_64x32x128 = shape_tuple_to_enum(64, 32, 128), + TileShape_64x64x128 = shape_tuple_to_enum(64, 64, 128), + TileShape_64x128x128 = shape_tuple_to_enum(64, 128, 128), + TileShape_64x256x128 = shape_tuple_to_enum(64, 256, 128), + TileShape_64x512x128 = shape_tuple_to_enum(64, 512, 128), + TileShape_128x16x128 = shape_tuple_to_enum(128, 16, 128), + TileShape_128x32x128 = shape_tuple_to_enum(128, 32, 128), + TileShape_128x64x128 = shape_tuple_to_enum(128, 64, 128), + TileShape_128x128x128 = shape_tuple_to_enum(128, 128, 128), + TileShape_128x256x128 = shape_tuple_to_enum(128, 256, 128), + TileShape_256x128x128 = shape_tuple_to_enum(256, 128, 128), + TileShape_256x256x128 = shape_tuple_to_enum(256, 256, 128) }; template constexpr auto get_tile_shape() { using namespace cute; - if constexpr (Shape_MNK == TileShape::TileShape_64x16x128) { - return cute::Shape<_64, _16, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x32x128) { - return cute::Shape<_64, _32, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x64x128) { - return cute::Shape<_64, _64, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x128x128) { - return cute::Shape<_64, _128, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x256x128) { - return cute::Shape<_64, _256, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x512x128) { - return cute::Shape<_64, _512, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x16x128) { - return cute::Shape<_128, _16, 
_128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x32x128) { - return cute::Shape<_128, _32, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x64x128) { - return cute::Shape<_128, _64, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x128x128) { - return cute::Shape<_128, _128, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x256x128) { - return cute::Shape<_128, _256, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_256x128x128) { - return cute::Shape<_256, _128, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_256x256x128) { - return cute::Shape<_256, _256, _128>{}; - } + static_assert(Shape_MNK != TileShape::Undefined, "TileShape is undefined"); + + constexpr auto shape_tuple = enum_to_shape_tuple(Shape_MNK); + return cute::Shape(shape_tuple)>, cute::Int(shape_tuple)>, + cute::Int(shape_tuple)>>{}; } -static auto get_tile_shape_name(TileShape Shape_MNK) { - if (Shape_MNK == TileShape::TileShape_64x16x128) { - return "64x16x128"; - } else if (Shape_MNK == TileShape::TileShape_64x32x128) { - return "64x32x128"; - } else if (Shape_MNK == TileShape::TileShape_64x64x128) { - return "64x64x128"; - } else if (Shape_MNK == TileShape::TileShape_64x128x128) { - return "64x128x128"; - } else if (Shape_MNK == TileShape::TileShape_64x256x128) { - return "64x256x128"; - } else if (Shape_MNK == TileShape::TileShape_64x512x128) { - return "64x512x128"; - } else if (Shape_MNK == TileShape::TileShape_128x16x128) { - return "128x16x128"; - } else if (Shape_MNK == TileShape::TileShape_128x32x128) { - return "128x32x128"; - } else if (Shape_MNK == TileShape::TileShape_128x64x128) { - return "128x64x128"; - } else if (Shape_MNK == TileShape::TileShape_128x128x128) { - return "128x128x128"; - } else if (Shape_MNK == TileShape::TileShape_128x256x128) { - return "128x256x128"; - } else if (Shape_MNK == TileShape::TileShape_256x128x128) { - return "256x128x128"; - } else if (Shape_MNK == TileShape::TileShape_256x256x128) { - return "256x256x128"; +template +static std::string get_tile_shape_name(TEnum Shape_MNK) { + static_assert(std::is_enum_v && std::is_same_v, int>, + "TEnum must be an enum with underlying type int"); + if ((int)Shape_MNK == 0) { + return "undefined"; + } else if ((int)Shape_MNK == 1) { + return "heuristic"; + } else { + auto [m, n, k] = enum_to_shape_tuple(Shape_MNK); + return std::to_string(m) + "x" + std::to_string(n) + "x" + std::to_string(k); } - return "Unknown shape"; } -enum class ClusterShape { - ClusterShape_1x1x1, - ClusterShape_2x1x1, - ClusterShape_1x2x1, - ClusterShape_2x2x1, - ClusterShape_1x4x1, - ClusterShape_4x2x1, - ClusterShape_2x4x1, - ClusterShape_4x4x1, - ClusterShape_1x8x1, - ClusterShape_8x1x1 +enum class ClusterShape : int { + Undefined = 0, + ClusterShape_1x1x1 = shape_tuple_to_enum(1, 1, 1), + ClusterShape_2x1x1 = shape_tuple_to_enum(2, 1, 1), + ClusterShape_1x2x1 = shape_tuple_to_enum(1, 2, 1), + ClusterShape_2x2x1 = shape_tuple_to_enum(2, 2, 1), + ClusterShape_1x4x1 = shape_tuple_to_enum(1, 4, 1), + ClusterShape_4x1x1 = shape_tuple_to_enum(4, 1, 1), + ClusterShape_4x2x1 = shape_tuple_to_enum(4, 2, 1), + ClusterShape_2x4x1 = shape_tuple_to_enum(2, 4, 1), + ClusterShape_4x4x1 = shape_tuple_to_enum(4, 4, 1), + ClusterShape_1x8x1 = shape_tuple_to_enum(1, 8, 1), + ClusterShape_8x1x1 = shape_tuple_to_enum(8, 1, 1) }; -static auto get_cluster_shape_name(ClusterShape Shape_MNK) { - if (Shape_MNK == ClusterShape::ClusterShape_1x1x1) { - return "1x1x1"; - } 
else if (Shape_MNK == ClusterShape::ClusterShape_2x1x1) { - return "2x1x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_1x2x1) { - return "1x2x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { - return "2x2x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { - return "1x8x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { - return "8x1x1"; +static std::string get_cluster_shape_name(ClusterShape Shape_MNK) { + if (Shape_MNK == ClusterShape::Undefined) { + return "undefined"; + } else { + auto [m, n, k] = enum_to_shape_tuple(Shape_MNK); + return std::to_string(m) + "x" + std::to_string(n) + "x" + std::to_string(k); } - return "Unknown shape"; } template @@ -297,10 +278,22 @@ constexpr auto get_cluster_shape() { return cute::Shape<_1, _2, _1>{}; } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { return cute::Shape<_2, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x1x1) { + return cute::Shape<_4, _1, _1>{}; } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { return cute::Shape<_1, _8, _1>{}; } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { return cute::Shape<_8, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x4x1) { + return cute::Shape<_1, _4, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x2x1) { + return cute::Shape<_4, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x4x1) { + return cute::Shape<_2, _4, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x4x1) { + return cute::Shape<_4, _4, _1>{}; + } else { + return cute::Shape<_0, _0, _0>{}; } } @@ -314,7 +307,8 @@ struct CutlassGemmConfig { BLACKWELL = 1u << 4, GROUPED_GEMM = 1u << 5, FP8_ONLY = 1u << 6, - FP4_ONLY = 1u << 7 + FP4_ONLY = 1u << 7, + FP8FP4_MIXED = 1u << 8 }; CutlassTileConfig tile_config_sm80 = CutlassTileConfig::ChooseWithHeuristic; @@ -329,10 +323,17 @@ struct CutlassGemmConfig { MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + ClusterShape dynamic_cluster_shape = ClusterShape::Undefined; + ClusterShape fallback_cluster_shape = ClusterShape::Undefined; bool enableCudaKernel = false; int sm_version = 80; // Use 80 as a catch all for <90 bool is_tma_warp_specialized = false; + enum class EpilogueFusionType : int { NONE, FINALIZE }; + + EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE; + bool swap_ab = false; + CutlassGemmConfig() = default; CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, @@ -352,15 +353,24 @@ struct CutlassGemmConfig { sm_version(90), is_tma_warp_specialized(true) {} + // If dynamic_cluster_shape is provided, dynamic CGA will be enabled and cluster_shape will be + // interpreted as whether to use 1 or 2 SM mode, otherwise static cluster shape is used. 
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule, - ClusterShape cluster_shape) + ClusterShape cluster_shape, + ClusterShape dynamic_cluster_shape = ClusterShape::Undefined, + ClusterShape fallback_cluster_shape = ClusterShape::Undefined, + int sm_version = 100) : tile_config_sm100(tile_config_sm100), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape), - sm_version(100), - is_tma_warp_specialized(true) {} + dynamic_cluster_shape(dynamic_cluster_shape), + fallback_cluster_shape(fallback_cluster_shape), + sm_version(sm_version), + is_tma_warp_specialized(true) { + TLLM_CHECK_WITH_INFO(sm_version >= 100 && sm_version < 120, "Expected SM 10x version"); + } CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120, MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule, @@ -373,26 +383,38 @@ struct CutlassGemmConfig { is_tma_warp_specialized(true) {} int getTileConfigAsInt() const { - if (sm_version == 120) return (int)tile_config_sm120; - if (sm_version == 110) return (int)tile_config_sm100; - if (sm_version >= 100) return (int)tile_config_sm100; + if (sm_version == 120 || sm_version == 121) return (int)tile_config_sm120; + if (sm_version >= 100 && sm_version < 120) return (int)tile_config_sm100; if (sm_version == 90) return (int)tile_config_sm90; if (sm_version < 90) return (int)tile_config_sm80; assert(false && "Invalid SM version"); return -1; } + std::string getTileConfigAsName() const { + if (sm_version == 120 || sm_version == 121) return get_tile_shape_name(tile_config_sm120); + if (sm_version >= 100 && sm_version < 120) return get_tile_shape_name(tile_config_sm100); + if (sm_version == 90) return get_tile_shape_name(tile_config_sm90); + if (sm_version < 90) return std::to_string((int)tile_config_sm80); + assert(false && "Invalid SM version"); + return "invalid"; + } + std::string toString() const { std::stringstream tactic; tactic << "Cutlass GEMM Tactic"; if (is_tma_warp_specialized) { assert(sm_version >= 90 && "Invalid cutlass GEMM config"); tactic << "\n\tstyle=TMA Warp Specialized" - << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt() - << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsName() + << "\n\tcluster shape ID: " << get_cluster_shape_name(cluster_shape) + << "\n\tdynamic cluster shape ID: " << get_cluster_shape_name(dynamic_cluster_shape) + << "\n\tfallback cluster shape ID: " << get_cluster_shape_name(fallback_cluster_shape) << "\n\tmainloop sched: " << (int)mainloop_schedule << "\n\tepi sched: " << (int)epilogue_schedule - << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false") + << "\n\tepilogue fusion type: " << (int)epilogue_fusion_type + << "\n\tswap_ab: " << (swap_ab ? 
"true" : "false"); } else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { assert(sm_version < 90 && "Invalid cutlass GEMM config"); @@ -412,22 +434,26 @@ struct CutlassGemmConfig { inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) { // clang-format off - if (config.is_tma_warp_specialized) - { - out << "tile_config_sm90_enum: " << config.getTileConfigAsInt() - << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) - << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) - << ", cluster_shape_enum: " << int(config.cluster_shape) - << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); - } - else - { - out << "tile_config_enum: " << config.getTileConfigAsInt() - << ", split_k_style_enum: " << int(config.split_k_style) - << ", split_k_factor: " << config.split_k_factor - << ", stages: " << config.stages - << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); - } + if (config.is_tma_warp_specialized) + { + out << "tile_config_sm90_enum: " << config.getTileConfigAsInt() + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape) + << ", dynamic_cluster_shape_enum: " << int(config.dynamic_cluster_shape) + << ", fallback_cluster_shape_enum: " << int(config.fallback_cluster_shape) + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false") + << ", epilogue_fusion_type: " << int(config.epilogue_fusion_type) + << ", swap_ab: " << (config.swap_ab ? "true" : "false"); + } + else + { + out << "tile_config_enum: " << config.getTileConfigAsInt() + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? 
"true" : "false"); + } // clang-format on return out; } diff --git a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp index 4ba4fc9f20..5a3b5f2302 100644 --- a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp +++ b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp @@ -34,7 +34,7 @@ #include "cute/tensor.hpp" #include "cute/util/print.hpp" -using namespace cute; +namespace cutlass::util { /// Function object that applies an index to its argument template @@ -48,8 +48,8 @@ struct IndexedGather { CUTE_HOST_DEVICE friend void print(IndexedGather const& s) { cute::print("Indexed{"); - print(s.indices_); - print("}"); + cute::print(s.indices_); + cute::print("}"); } Iter indices_; @@ -73,23 +73,23 @@ struct CustomStride { CUTE_HOST_DEVICE friend void print(CustomStride const& s) { cute::print("Custom{"); - print(s.func_); + cute::print(s.func_); cute::print(","); - print(s.stride_); + cute::print(s.stride_); cute::print("}"); } template CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) { - return CustomStride(s.func_, - safe_div(s.stride_, div)); + return CustomStride( + s.func_, cute::safe_div(s.stride_, div)); } // Circumvent the requirement on make_layout that shape and stride are integral template CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) { - return Layout(shape, stride); + return cute::Layout(shape, stride); } Func func_; @@ -98,6 +98,7 @@ struct CustomStride { template CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) { + using namespace cute; // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather // stride auto idx = find_if(stride, [](auto x) { @@ -112,11 +113,13 @@ CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& template CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) { + using namespace cute; Layout matrix_layout = make_identity_layout(shape); auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); } +} // namespace cutlass::util namespace cute { diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 8fc256ba31..34a90c65f9 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -158,7 +158,7 @@ std::vector get_candidate_tiles( CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; case CutlassGemmType::Fp8: if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { - if (sm == 89 || sm == 120) { + if (sm == 89 || sm >= 120) { return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, @@ -264,6 +264,119 @@ bool sm90_supports_mcast_along_n(CutlassTileConfigSM90 const tile) { #endif } +std::vector get_candidate_configs_sm100_dynamic_cluster_shape( + int sm, 
CutlassGemmConfig::CandidateConfigTypeParam const config, EpilogueScheduleType schedule, + ClusterShape const dynamic_cluster_shape, ClusterShape const fallback_cluster_shape) { + auto cluster1sm = ClusterShape::ClusterShape_1x1x1; + auto cluster2sm = ClusterShape::ClusterShape_2x1x1; + bool supports_2sm = dynamic_cluster_shape == ClusterShape::Undefined || + std::get<0>(enum_to_shape_tuple(dynamic_cluster_shape)) % 2 == 0; + + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + if (sm == 100) { + if (schedule != EpilogueScheduleType::TMA) return {}; + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, schedule, + cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + if (supports_2sm) { + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + } + } + + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, + schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, + schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + if (supports_2sm) { + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + } + return candidate_configs; + } + + std::vector> tile_configs{ + {CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x32x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x128x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x256x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x64x128B, cluster1sm}, + }; + + if (supports_2sm) { + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x128x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x256x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x64x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x64x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x128x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x256x128B, cluster2sm}); + } + + if (config & CutlassGemmConfig::FP8_ONLY) { + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, cluster1sm}); + // TODO: re-enable when handled by the MoE GEMM dispatch + // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, + // ClusterShape::ClusterShape_1x1x1 }); + } + + for (auto [tile, cluster] : tile_configs) { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, schedule, + cluster, dynamic_cluster_shape, fallback_cluster_shape, + sm}; + candidate_configs.push_back(config); + } + return candidate_configs; +} + +std::vector 
get_candidate_configs_sm100( + CutlassGemmConfig::CandidateConfigTypeParam const config, int sm) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM100 + return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::TMA, + ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined, + ClusterShape::Undefined, sm}}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + std::vector candidate_configs; + for (auto schedule : {EpilogueScheduleType::TMA, EpilogueScheduleType::NO_SMEM}) { + // TODO The tactic profiling is a bit long with all of these shapes enabled + // Shape 4x4x1 shapes do not seem to give better performance in the cases I tested so we + // disable it here + auto cluster_shapes = { + ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_4x1x1, + ClusterShape::ClusterShape_4x2x1 /*, ClusterShape::ClusterShape_4x4x1*/}; + for (auto cluster_shape : cluster_shapes) { + auto fallback_cluster_shape = cluster_shape == ClusterShape::ClusterShape_1x1x1 + ? ClusterShape::ClusterShape_1x1x1 + : ClusterShape::ClusterShape_2x1x1; + auto configs = get_candidate_configs_sm100_dynamic_cluster_shape( + sm, config, schedule, cluster_shape, fallback_cluster_shape); + candidate_configs.insert(candidate_configs.end(), configs.begin(), configs.end()); + } + + auto configs = get_candidate_configs_sm100_dynamic_cluster_shape( + sm, config, schedule, ClusterShape::Undefined, ClusterShape::Undefined); + candidate_configs.insert(candidate_configs.end(), configs.begin(), configs.end()); + } + return candidate_configs; + } else { + TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); + } +#endif +} + std::vector get_candidate_configs_sm90( CutlassGemmConfig::CandidateConfigTypeParam const config) { auto tiles = get_candidate_tiles_sm90(config); @@ -330,7 +443,7 @@ std::vector get_candidate_configs_sm90( return candidate_configs; } -std::vector get_candidate_configs_sm100( +/*std::vector get_candidate_configs_sm100( CutlassGemmConfig::CandidateConfigTypeParam const config) { #ifdef FAST_BUILD // Fast build disables all configs except this one for SM100 @@ -413,7 +526,7 @@ std::vector get_candidate_configs_sm100( TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); } #endif -} +}*/ std::vector get_candidate_configs_sm110( CutlassGemmConfig::CandidateConfigTypeParam const config) { @@ -538,7 +651,7 @@ std::vector get_candidate_configs( return get_candidate_configs_sm110(config_type_param); } if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { - return get_candidate_configs_sm100(config_type_param); + return get_candidate_configs_sm100(config_type_param, sm); } if (sm >= 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { return get_candidate_configs_sm120(config_type_param); diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h index f7ea83cdb0..80c024bdb7 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h @@ -24,7 +24,8 @@ namespace tensorrt_llm { namespace kernels { namespace cutlass_kernels { -template +template struct should_filter_tma_warp_specialized_gemm_problem_shape { #ifdef FAST_BUILD using SupportedCtaShape = @@ -32,15 +33,16 @@ struct should_filter_tma_warp_specialized_gemm_problem_shape { 
using SupportedCgaShape = cute::Shape; constexpr static bool value = !cute::is_same_v || - !cute::is_same_v; + !cute::is_same_v || DYNAMIC_CGA; #else constexpr static bool value = false; #endif }; -template +template constexpr static bool should_filter_tma_warp_specialized_gemm_problem_shape_v = should_filter_tma_warp_specialized_gemm_problem_shape::value; + DYNAMIC_CGA, ActivationType>::value; std::vector get_candidate_configs( int sm, int const max_split_k, diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 14ba601b39..2cc10e382b 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -46,6 +46,7 @@ namespace tkc = tensorrt_llm::cutlass_extensions; namespace tensorrt_llm { namespace kernels { namespace cutlass_kernels { +using namespace cute; template 2 && arch::kMinComputeCapability < 80) { // Multistage only supported on Ampere std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); - throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][filter_and_run_mixed_gemm] " + err_msg); } else if constexpr (Stages == 2 && arch::kMinComputeCapability >= 89) { // Multistage only supported on Ampere std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); - throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][filter_and_run_mixed_gemm] " + err_msg); } else if constexpr (cutlass::platform::is_same::value && arch::kMinComputeCapability < 89) { // FP8 activation type only supported on Ada+ GPUs std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8"; - throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][filter_and_run_mixed_gemm] " + err_msg); } else { generic_mixed_gemm_kernelLauncher() || is_fp8() || is_fp8() || is_fp8() || is_fp8(); @@ -362,17 +369,17 @@ void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, break; case tkc::CutlassTileConfig::Undefined: throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); + "[TensorRT LLM Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); break; case tkc::CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have " + "[TensorRT LLM Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have " "already been set by " "heuristic."); break; default: throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed " + "[TensorRT LLM Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed " "type GEMM."); break; } @@ -380,7 +387,7 @@ void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, // This is not a limitation in CUTLASS. 
We just do not need to support this case. std::string err_msg = "The activation type must equal the scale, bias and output types on Ampere and earlier."; - throw std::runtime_error("[TensorRT-LLm Error][dispatch_gemm_to_cutlass] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][dispatch_gemm_to_cutlass] " + err_msg); } } @@ -388,6 +395,7 @@ template CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); int device{-1}; tk::check_cuda_error(cudaGetDevice(&device)); sm_ = tk::getSMVersion(); @@ -398,7 +406,9 @@ CutlassFpAIntBGemmRunner CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() {} + OutputType>::~CutlassFpAIntBGemmRunner() { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); +} template @@ -414,6 +424,7 @@ void CutlassFpAIntBGemmRunner< tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); if (sm_ >= 75 && sm_ < 80) { dispatch_gemm_to_cutlass( @@ -429,7 +440,7 @@ void CutlassFpAIntBGemmRunner< ((__CUDACC_VER_MAJOR__ < 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) if constexpr (cutlass::platform::is_same::value) { throw std::runtime_error( - "[TensorRT-LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] INT4xFP8 GEMM for Ada " + "[TensorRT LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] INT4xFP8 GEMM for Ada " "needs " "CUDA>=12.4"); } @@ -442,13 +453,13 @@ void CutlassFpAIntBGemmRunner< static_assert(!cutlass::platform::is_same::value || cutlass::platform::is_same::value, "ScaleZeroType must be half for activation=fp8"); - sm90_dispatch_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_gemm_to_cutlass( A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); } else { throw std::runtime_error( - "[TensorRT-LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] Arch unsupported for " + "[TensorRT LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] Arch unsupported for " "CUTLASS mixed type " "GEMM"); } @@ -465,6 +476,7 @@ void CutlassFpAIntBGemmRunner( @@ -487,6 +499,7 @@ void CutlassFpAIntBGemmRunner((ActivationType const*)A, (WeightType const*)B, (ScaleZeroType const*)weight_scales, nullptr, nullptr, @@ -519,6 +534,7 @@ void CutlassFpAIntBGemmRunner::getWorkspaceSize(int const m, int const n, int const k) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); // For Hopper, we have to allocate large memory size in case for stream-K if (sm_ == 90) { // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L878-L892 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h index a81fffde9d..e01dbd279c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h @@ -26,7 +26,7 @@ namespace tensorrt_llm { namespace kernels { -namespace cutlass_kernels { +namespace cutlass_kernels_oss { namespace tk = tensorrt_llm::common; namespace tkc = tensorrt_llm::cutlass_extensions; @@ -43,6 +43,7 @@ void sm90_dispatch_epilogue_schedules( ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int 
const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); switch (gemm_config.epilogue_schedule) { case tkc::EpilogueScheduleType::AUTO: using EpilogueScheduleType = @@ -57,7 +58,7 @@ void sm90_dispatch_epilogue_schedules( break; default: throw std::runtime_error( - "[TensorRT-LLM Error][fpA_intB][sm90_dispatch_epilogue_schedules] epilogue schedule " + "[TensorRT LLM Error][fpA_intB][sm90_dispatch_epilogue_schedules] epilogue schedule " "config is invalid for " "mixed " "type GEMM."); @@ -105,6 +106,8 @@ void sm90_dispatch_mainloop_schedules( ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + constexpr bool tile_shapes_supported = are_tile_shapes_supported(); if constexpr (tile_shapes_supported) { @@ -122,7 +125,7 @@ void sm90_dispatch_mainloop_schedules( break; default: throw std::runtime_error( - "[TensorRT-LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] mainloop schedule " + "[TensorRT LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] mainloop schedule " "config is invalid " "for " "mixed type GEMM."); @@ -130,7 +133,7 @@ void sm90_dispatch_mainloop_schedules( } } else { throw std::runtime_error( - "[TensorRT-LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] Unsupported CTA and " + "[TensorRT LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] Unsupported CTA and " "Cluster shapes for " "mixed type GEMM."); } @@ -146,6 +149,7 @@ void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); switch (gemm_config.cluster_shape) { case tkc::ClusterShape::ClusterShape_1x1x1: sm90_dispatch_mainloop_schedules::type; if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v< - cutlass::arch::Sm90, CTAShape, ClusterShape, ActivationType>) { + cutlass::arch::Sm90, CTAShape, ClusterShape, false, ActivationType>) { using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter::type; @@ -192,7 +195,7 @@ void sm90_generic_mixed_gemm_kernelLauncher( int cta_shape_k = cute::size<2>(TileShape{}); if (group_size % cta_shape_k != 0) { std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner]" + err_msg); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner]" + err_msg); } if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { @@ -244,7 +247,7 @@ void sm90_generic_mixed_gemm_kernelLauncher( Gemm gemm; if (gemm.get_workspace_size(args) > workspace_bytes) { - TLLM_LOG_ERROR("[TensorRT-LLm Error][fpA_intB Runner] given workspace size insufficient."); + TLLM_LOG_ERROR("[TensorRT LLM Error][fpA_intB Runner] given workspace size insufficient."); } auto can_implement = gemm.can_implement(args); @@ -252,25 +255,25 @@ void sm90_generic_mixed_gemm_kernelLauncher( std::string err_msg = "fpA_intB cutlass kernel will fail for params. 
Error: " + std::string(cutlassGetStatusString(can_implement)); std::cout << err_msg << std::endl; - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); } auto init_status = gemm.initialize(args, workspace, stream); if (init_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); } auto run_status = gemm.run(stream); if (run_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); } } else { std::stringstream ss; - ss << "[TensorRT-LLm Error][fpA_intB Runner] Config (" << (int64_t)cute::size<0>(CTAShape{}) + ss << "[TensorRT LLM Error][fpA_intB Runner] Config (" << (int64_t)cute::size<0>(CTAShape{}) << "," << (int64_t)cute::size<1>(CTAShape{}) << "," << (int64_t)cute::size<2>(CTAShape{}) << ") (" << (int64_t)cute::size<0>(ClusterShape{}) << "," << (int64_t)cute::size<1>(ClusterShape{}) << "," << (int64_t)cute::size<2>(ClusterShape{}) @@ -281,12 +284,12 @@ void sm90_generic_mixed_gemm_kernelLauncher( #else // COMPILE_HOPPER_TMA_GEMMS throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB Runner] Please recompile with support for hopper by passing " + "[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for hopper by passing " "90-real as an arch " "to build_wheel.py."); #endif // COMPILE_HOPPER_TMA_GEMMS } -} // namespace cutlass_kernels +} // namespace cutlass_kernels_oss } // namespace kernels } // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index b85decebcd..b77efbcac1 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -15,10 +15,10 @@ */ #pragma once -#include #include #include +#include #include #include "./common.h" @@ -32,17 +32,10 @@ #include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" #ifdef ENABLE_FP4 -#if CUDA_VERSION >= 12080 #include #endif -#endif namespace tensorrt_llm::kernels::cutlass_kernels { -template -constexpr auto transpose_stride(T const& t) { - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), - cute::get<1>(t)); -} template struct GroupedGemmInput { @@ -71,8 +64,6 @@ struct GroupedGemmInput { }; struct TmaWarpSpecializedGroupedGemmInput { - template - using TransposeStride = decltype(transpose_stride(T{})); template using TransposeLayoutTag = std::conditional_t, @@ -83,14 +74,24 @@ struct TmaWarpSpecializedGroupedGemmInput { static_assert( std::is_same_v>); - // Layout for A and B is transposed and then swapped in the implementation - // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM - using LayoutA = - TransposeLayoutTag; // Layout type for A matrix operand - using LayoutB = - TransposeLayoutTag; // Layout type for B matrix operand - using LayoutC = 
- TransposeLayoutTag; // Layout type for C matrix operand + // These are always the layout of A & B matrices, activations and weights will be assigned to + // either A or B based on swap_ab + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + + // When using Swap A&B we need to transpose the output matrix + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + using LayoutC_T = TransposeLayoutTag; + using LayoutD_T = TransposeLayoutTag; + + using StrideA = std::remove_pointer_t>; + using StrideB = std::remove_pointer_t>; + + using StrideC = std::remove_pointer_t>; + using StrideD = std::remove_pointer_t>; + using StrideC_T = std::remove_pointer_t>; + using StrideD_T = std::remove_pointer_t>; constexpr static int NVFP4BlockScaleVectorSize = 16; constexpr static int MXFPXBlockScaleVectorSize = 32; @@ -121,14 +122,6 @@ struct TmaWarpSpecializedGroupedGemmInput { return (dim + alignment - 1) / alignment * alignment; } - using StrideA = - std::remove_pointer_t>; // Use B because they will - // be swapped - using StrideB = - std::remove_pointer_t>; // Use A because they will - // be swapped - using StrideC = std::remove_pointer_t>; - #ifdef ENABLE_FP8 template constexpr static bool IsFP8_v = @@ -144,47 +137,40 @@ struct TmaWarpSpecializedGroupedGemmInput { using ProblemShape = cutlass::gemm::GroupProblemShape>; + bool swap_ab = false; ProblemShape shape_info{}; - StrideA* stride_a = nullptr; - StrideB* stride_b = nullptr; + void* stride_act = nullptr; + void* stride_weight = nullptr; - void const** ptr_a = nullptr; - void const** ptr_b = nullptr; + void const** ptr_act = nullptr; + void const** ptr_weight = nullptr; // C is currently the same in both epilogues - StrideC* stride_c = nullptr; + void* stride_c = nullptr; void const** ptr_c = nullptr; - struct DefaultEpilogue { - using LayoutD = - TransposeLayoutTag; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t>; - - StrideD* stride_d = nullptr; - void** ptr_d = nullptr; - }; + // D is used in all cases except fused finalize + void* stride_d = nullptr; + void** ptr_d = nullptr; struct FusedFinalizeEpilogue { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride>; - using StrideRouterScales = TransposeStride>; + using StrideFinalOutput_T = cutlass::detail::TagToStrideC_t; + using StrideFinalOutput = cutlass::detail::TagToStrideC_t; void* ptr_final_output = nullptr; + StrideFinalOutput_T stride_final_output_transposed{}; StrideFinalOutput stride_final_output{}; - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; - - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; + void const** ptr_bias = nullptr; + float const** ptr_router_scales = nullptr; - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; + int const** ptr_source_token_index = nullptr; + int num_rows_in_final_output = 0; + int shape_override = -1; - size_t num_rows_in_final_output = 0; + bool use_reduction = true; }; - DefaultEpilogue default_epilogue; FusedFinalizeEpilogue fused_finalize_epilogue; enum class EpilogueFusion { NONE, ACTIVATION, GATED_ACTIVATION, FINALIZE }; @@ -195,11 +181,11 @@ struct TmaWarpSpecializedGroupedGemmInput { using ElementSF = uint8_t; using MXFPXElementSF = ElementSF; // Just an alias for now using NVFP4ElementSF = ElementSF; // Just an alias for now - ElementSF const** fpX_block_scaling_factors_A = nullptr; - ElementSF 
const** fpX_block_scaling_factors_B = nullptr; + ElementSF const** fpX_block_scaling_factors_act = nullptr; + ElementSF const** fpX_block_scaling_factors_weight = nullptr; - void* fpX_block_scaling_factors_stride_A = nullptr; - void* fpX_block_scaling_factors_stride_B = nullptr; + void* fpX_block_scaling_factors_stride_act = nullptr; + void* fpX_block_scaling_factors_stride_weight = nullptr; enum class FpXBlockScalingType { MXFPX, NVFP4, NONE }; FpXBlockScalingType fpX_block_scaling_type = FpXBlockScalingType::NONE; @@ -231,21 +217,19 @@ struct TmaWarpSpecializedGroupedGemmInput { size_t gemm_workspace_size = 0; // Whether to enable PDL (Programmatic Dependent Launch). - bool enable_pdl; + bool enable_pdl{}; - static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); + static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type); void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size, FpXBlockScalingType scaling_type); - bool isValid() const { return stride_a != nullptr && ptr_a != nullptr; } + bool isValid() const { return stride_act != nullptr && ptr_act != nullptr; } - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, - int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); + void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, + bool use_reduction); std::string toString() const; }; @@ -275,7 +259,6 @@ class MoeGemmRunner { #else static constexpr bool use_wfp4a16 = false; #endif - #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v || std::is_same_v) && @@ -289,17 +272,16 @@ class MoeGemmRunner { #else static constexpr bool use_fp8 = false; static constexpr bool use_w4afp8 = false; - static constexpr bool use_wfp4afp4 = false; #endif static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) static constexpr bool use_fp4 = std::is_same_v; - static constexpr bool use_wfp4afp4 = + static constexpr bool use_wfp4afp8 = std::is_same_v && std::is_same_v; #else static constexpr bool use_fp4 = false; - static constexpr bool use_wfp4afp4 = false; + static constexpr bool use_wfp4afp8 = false; #endif void moeGemmBiasAct(GroupedGemmInput inputs, @@ -308,15 +290,19 @@ class MoeGemmRunner { void moeGemm(GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs); - std::vector getConfigs() const; - static std::vector getConfigs(int sm); - static std::vector getTmaWarpSpecializedConfigs(int sm); - static std::vector getBlackwellConfigs(int sm); - static std::vector getHopperConfigs(int sm); + std::vector getConfigs( + bool supports_finalize_fusion) const; + static std::vector getConfigs( + int sm, bool supports_finalize_fusion); + static std::vector getTmaWarpSpecializedConfigs( + int sm, bool supports_finalize_fusion); static std::vector getAmpereConfigs(int sm); [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; - [[nodiscard]] bool supportsTmaWarpSpecialized() const; + + [[nodiscard]] bool supportsTmaWarpSpecialized() const { return supportsTmaWarpSpecialized(sm_); } + + [[nodiscard]] static bool supportsTmaWarpSpecialized(int sm); [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int 
gemm_n, int gemm_k) const; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 2ad23d7f7d..e278269b97 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "cutlass/gemm/gemm.h" #include "moe_gemm_kernels.h" #include "tensorrt_llm/common/assert.h" @@ -219,6 +221,8 @@ struct MOEParallelismConfig { } }; +enum class MoeGemmId : int { Undefined = 0, GEMM_1, GEMM_2 }; + struct QuantParams { // Int weight only quantization params struct { @@ -426,14 +430,15 @@ class CutlassMoeFCRunnerInterface { bool use_awq) = 0; virtual void setTactic(std::optional gemm1_config, std::optional gemm2_config) = 0; - virtual std::vector getTactics() = 0; + virtual std::vector getTactics(MoeGemmId gemm_id) = 0; virtual void runMoe(void const* input_activations, void const* input_sf, - int const* token_selected_experts, float const* token_final_scales, - void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationParams fc1_activation_type, void const* fc2_expert_weights, - void const* fc2_expert_biases, QuantParams quant_params, - int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, + bool const swizzled_input_sf, int const* token_selected_experts, + float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationParams fc1_activation_type, + void const* fc2_expert_weights, void const* fc2_expert_biases, + QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, @@ -459,26 +464,24 @@ class CutlassMoeFCRunnerInterface { int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl) = 0; - virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output, - int64_t const* const expert_first_token_offset, - TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template, - void const* const fc2_expert_weights, void const* const fc2_expert_biases, - void const* const fc2_int_scales, float const* const fc2_fp8_dequant, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, - QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, - int const* const unpermuted_row_to_permuted_row, - int const* permuted_row_to_unpermuted_row, - int const* const token_selected_experts, - int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, - int64_t const expanded_num_rows, int64_t const hidden_size, - int64_t const inter_size, int const num_experts_per_node, - int64_t const experts_per_token, float const** alpha_scale_ptr_array, - bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, - cudaStream_t stream, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, - bool min_latency_mode, int* num_active_experts_per, - int* active_expert_global_ids, bool enable_pdl) = 0; + virtual void gemm2( + void const* const input, void* const gemm_output, void* const final_output, + int64_t const* 
const expert_first_token_offset, + TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template, + void const* const fc2_expert_weights, void const* const fc2_expert_biases, + void const* const fc2_int_scales, float const* const fc2_fp8_dequant, + TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, + QuantParams quant_params, float const* const token_topk_unpermuted_scales, + float const* const token_topk_permuted_scales, + int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, + int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, + int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, + int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, + void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream, + MOEParallelismConfig parallelism_config, bool const enable_alltoall, + cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, + int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl) = 0; virtual std::pair computeStridesTmaWarpSpecializedDispatch( @@ -490,7 +493,8 @@ class CutlassMoeFCRunnerInterface { float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, bool enable_pdl, + void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, + float const* router_scales, int const* permuted_row_to_unpermuted_row, bool enable_pdl, cudaStream_t stream) = 0; virtual std::pair @@ -509,13 +513,13 @@ class CutlassMoeFCRunnerInterface { virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; - bool use_deterministic_hopper_reduce_ = false; + bool use_fused_finalize_ = true; }; // Assumes inputs activations are row major. Weights need to be preprocessed by // th_op/weight_quantize.cc . Nested in a class to avoid multiple calls to cudaGetDeviceProperties // as this call can be expensive. Avoid making several duplicates of this class. 
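// getTactics(MoeGemmId) selects tactics per GEMM: it forwards to MoeGemmRunner::getConfigs(),
// and for MoeGemmId::GEMM_2 the candidate list may also include configs with the finalize
// epilogue fused when mayHaveFinalizeFused() reports support.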
-template || std::is_same_v) && !std::is_same_v; static constexpr bool use_w4afp8 = std::is_same_v && std::is_same_v; + static constexpr bool use_fp8_input = std::is_same_v; static_assert(!std::is_same_v, "Current logic requires backbone type to be >=16-bits"); static_assert(!std::is_same_v, @@ -601,25 +605,26 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { gemm2_config_ = std::move(gemm2_config); } - std::vector getTactics() override { - return moe_gemm_runner_.getConfigs(); + std::vector getTactics(MoeGemmId gemm_id) override { + return moe_gemm_runner_.getConfigs(gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused()); } - static std::vector getTactics(int sm) { + static std::vector getTactics(int sm, MoeGemmId gemm_id) { using RunnerType = decltype(moe_gemm_runner_); - return RunnerType::getConfigs(sm); + return RunnerType::getConfigs(sm, + gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm)); } - void runMoe(void const* input_activations, void const* input_sf, + void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts, - int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, bool use_lora, LoraParams& lora_params, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, + int64_t const hidden_size, int64_t const unpadded_hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, + char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row, + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, + LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) override; @@ -663,11 +668,12 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, - int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, - float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, cudaStream_t stream, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, - cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl); + int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, + int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, + void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, + bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, + bool enable_pdl); // Overrides to allow us to forward on to the internal functions with the pointers using the // correct type @@ -710,7 
+716,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, - int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, + int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream, @@ -727,10 +734,10 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { static_cast(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params, token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, - num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size, inter_size, - num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora, stream, - parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per, - active_expert_global_ids, enable_pdl); + num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size, unpadded_hidden_size, + inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, + fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode, + num_active_experts_per, active_expert_global_ids, enable_pdl); } virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override { @@ -747,7 +754,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, bool enable_pdl, + void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, + float const* router_scales, int const* permuted_row_to_unpermuted_row, bool enable_pdl, cudaStream_t stream) override { return Self::computeStridesTmaWarpSpecialized( expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, @@ -758,7 +766,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { fp4_act_flat1, fp4_act_flat2, quant_params, reinterpret_cast(bias1), reinterpret_cast(bias2), reinterpret_cast(gemm1_output), - reinterpret_cast(gemm2_output), enable_pdl, stream); + reinterpret_cast(gemm2_output), router_scales, + permuted_row_to_unpermuted_row, enable_pdl, stream); } std::pair @@ -789,8 +798,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { std::pair setupTmaWarpSpecializedInputs(int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, - int64_t inter_size, int64_t num_experts_per_node, - void const* input_activations_void, + int64_t unpadded_hidden_size, int64_t inter_size, + int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -811,7 +820,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* 
fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, bool enable_pdl, cudaStream_t stream); + UnfusedGemmOutputType* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, bool enable_pdl, cudaStream_t stream); static std::pair computeStridesTmaWarpSpecializedLowLatency( TmaWarpSpecializedGroupedGemmInput layout_info1, @@ -844,8 +854,13 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { } bool mayHaveFinalizeFused() const { - return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90 && - !use_deterministic_hopper_reduce_ && !use_w4_groupwise; + return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && + use_fused_finalize_ && !use_w4_groupwise; + } + + static bool mayHaveFinalizeFused(int sm) { + using RunnerType = decltype(moe_gemm_runner_); + return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise; } // TODO: This should eventually take the quant params to give more flexibility @@ -891,7 +906,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int const* const unpermuted_row_to_permuted_row, int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, - int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, + int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params, bool enable_pdl, cudaStream_t stream); @@ -951,14 +967,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { struct GemmProfilerBackend { public: using Config = cutlass_extensions::CutlassGemmConfig; - enum class GemmToProfile { Undefined = 0, GEMM_1, GEMM_2 }; + using GemmToProfile = MoeGemmId; void init(CutlassMoeFCRunnerInterface& runner, GemmToProfile gemm_to_profile, nvinfer1::DataType dtype, nvinfer1::DataType wtype, nvinfer1::DataType otype, - int num_experts, int k, int64_t hidden_size, int64_t inter_size, int64_t group_size, - ActivationType activation_type, bool bias, bool use_lora, bool min_latency_mode, - bool need_weights, MOEParallelismConfig parallelism_config, - bool const enable_alltoall) { + int num_experts, int k, int64_t hidden_size, int64_t unpadded_hidden_size, + int64_t inter_size, int64_t group_size, ActivationType activation_type, bool bias, + bool use_lora, bool min_latency_mode, bool need_weights, + MOEParallelismConfig parallelism_config, bool const enable_alltoall) { mInterface = &runner; mGemmToProfile = gemm_to_profile; mDType = dtype; @@ -968,6 +984,7 @@ struct GemmProfilerBackend { mNumExpertsPerNode = num_experts / parallelism_config.ep_size; mK = k; mExpertHiddenSize = hidden_size; + mExpertUnpaddedHiddenSize = unpadded_hidden_size; mExpertInterSize = inter_size; // Already divided by tp_size mGroupSize = group_size; mActivationType = activation_type; @@ -1001,12 +1018,12 @@ struct GemmProfilerBackend { CutlassMoeFCRunnerInterface* mInterface; GemmToProfile mGemmToProfile = GemmToProfile::Undefined; - std::vector mAllTacticsSaved; int mSM{}; int64_t mNumExperts{}; int64_t mNumExpertsPerNode{}; int64_t 
mK{}; int64_t mExpertHiddenSize{}; + int64_t mExpertUnpaddedHiddenSize{}; int64_t mExpertInterSize{}; int64_t mGroupSize{}; ActivationType mActivationType{}; @@ -1022,7 +1039,11 @@ struct GemmProfilerBackend { // This will be a unique value for every iteration of warmup and actual bench constexpr static int64_t NUM_ROUTING_SAMPLES = 16; - std::array mTmaInputCache; + constexpr static int64_t NUM_FUSION_TYPES = 2; + constexpr static int64_t NUM_SWAP_AB_TYPES = 2; + constexpr static int64_t NUM_WORKSPACES = NUM_FUSION_TYPES * NUM_SWAP_AB_TYPES; + TmaWarpSpecializedGroupedGemmInput mTmaInputCache[NUM_FUSION_TYPES][NUM_SWAP_AB_TYPES] + [NUM_ROUTING_SAMPLES]; QuantParams mQuantParams; bool mBias{}; @@ -1036,6 +1057,7 @@ struct GemmProfilerBackend { void prepareRouting(int num_tokens, char* workspace, bool enable_pdl, cudaStream_t stream); void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream); void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, bool enable_pdl, cudaStream_t stream); }; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h index e701d72fe7..01f107d095 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h @@ -64,18 +64,18 @@ void expandInputRowsKernelLauncher( int const k, int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, - bool enable_pdl, cudaStream_t stream); + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + void const* prequant_scales, bool enable_pdl, cudaStream_t stream); template void finalizeMoeRoutingKernelLauncher( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales, int const* unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* token_selected_experts, - int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, - int64_t const experts_per_token, int64_t const num_experts_per_node, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, - cudaStream_t stream); + int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const padded_cols, + int64_t const unpadded_cols, int64_t const experts_per_token, + int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool enable_pdl, cudaStream_t stream); } // namespace cutlass_kernels } // namespace tensorrt_llm::kernels diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h index da4be7c179..9d493b8ef0 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h @@ -1,16 +1,20 @@ /* - * 
SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { template void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl index 9a9ecafcd6..8ecc3fc18b 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl @@ -1,28 +1,30 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ -#include -#include - -#include - #include "cute/tensor.hpp" #include "cutlass/array.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh" +#include "tensorrt_llm/common/cudaUtils.h" -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { template void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, @@ -93,4 +95,4 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int)(result)); } -} // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h index badc07b574..ae2ad222b3 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h @@ -18,16 +18,19 @@ #include -#include "moe_gemm_kernels.h" - -namespace tensorrt_llm::kernels::cutlass_kernels { +#include "../../include/moe_gemm_kernels.h" +namespace tensorrt_llm::kernels::cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; // Keep in sync with the signature generated by generate_kernels.py -template + typename ClusterShape, bool IsMXFPX, bool DYNAMIC_CGA, bool BIAS, bool SwapAB> void tma_warp_specialized_generic_moe_gemm_kernelLauncher( TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, int multi_processor_count, - cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size); + cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size, + cute::Shape dynamic_cluster_shape, + cute::Shape fallback_cluster_shape); -} // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index a3e4a87398..db5788bfdd 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,13 +18,13 @@ #include #include +#include "../../include/moe_gemm_kernels.h" #include "../moe_tma_warp_specialized_traits.h" #include "cute/tensor.hpp" #include "cutlass/array.h" #include "cutlass/cutlass.h" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/fusion/operations.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -33,14 +33,7 @@ #include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_ref.h" -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" -#include "moe_gemm_kernels.h" +#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp" #include "moe_gemm_tma_ws_launcher.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -58,7 +51,8 @@ namespace tensorrt_llm { namespace kernels { -namespace cutlass_kernels { +namespace cutlass_kernels_oss { +using namespace tensorrt_llm::kernels::cutlass_kernels; using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion; // Constructs an object with specific arguments only if flag is true @@ -76,8 +70,18 @@ ReturnType construct_if_true(Args&&... args) { template auto deduce_layout_sf() { if constexpr (FLAG && A) { + // In moe_kernels.cu we rely on these two types being the same. This is not necessarily + // guaranteed by cutlass so we have a sanity check here. + static_assert(std::is_same_v, + "Deduced layout SF does not match for A and B"); return typename GemmGrouped::GemmKernel::CollectiveMainloop::LayoutSFA{}; } else if constexpr (FLAG && !A) { + // In moe_kernels.cu we rely on these two types being the same. This is not necessarily + // guaranteed by cutlass so we have a sanity check here. 
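// Illustrative note (not part of this patch): construct_if_true exists so that the branch
// of an `if constexpr` that does not apply is never instantiated, letting the launcher pass
// block-scaled mainloop arguments only when IsBlockScaled is set without tripping over
// ill-formed constructor calls on the other branch. A minimal standalone sketch of that
// pattern, with a hypothetical name:
//
//   template <bool FLAG, typename ReturnType, typename... Args>
//   ReturnType constructIfTrueSketch(Args&&... args) {
//     if constexpr (FLAG) {
//       return ReturnType{std::forward<Args>(args)...};
//     } else {
//       return ReturnType{};
//     }
//   }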
+ static_assert(std::is_same_v, + "Deduced layout SF does not match for A and B"); return typename GemmGrouped::GemmKernel::CollectiveMainloop::LayoutSFB{}; } else { return (void*)nullptr; @@ -85,18 +89,21 @@ auto deduce_layout_sf() { } template + typename EpilogueSchedule, typename EpilogueTag, EpilogueFusion FUSION, + typename TileShape, typename ClusterShape, bool IsMXFPX, bool DYNAMIC_CGA, bool BIAS, + bool SwapAB> struct DispatchToTmaWSFunction {}; // TMA WS specialized version template + typename EpilogueSchedule, typename EpilogueTag, EpilogueFusion FUSION, + typename TileShape, typename ClusterShape, bool IsMXFPX, bool DYNAMIC_CGA, bool BIAS, + bool SwapAB> void tma_warp_specialized_generic_moe_gemm_kernelLauncher( TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, - size_t* workspace_size) { + size_t* workspace_size, cute::Shape dynamic_cluster_shape, + cute::Shape fallback_cluster_shape) { if constexpr (ArchTag::kMinComputeCapability < 90) { TLLM_THROW("Invalid architecture instantiated"); } @@ -115,6 +122,14 @@ void tma_warp_specialized_generic_moe_gemm_kernelLauncher( "build_wheel.py."); } #endif +#ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS + else if constexpr (ArchTag::kMinComputeCapability == 103) { + // fallback sm100f logic is done in dispatchMoeGemmFinalDispatchTmaWarpSpecialized + TLLM_THROW( + "Please recompile with support for blackwell by passing 103-real as an arch to " + "build_wheel.py."); + } +#endif #ifndef COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS else if constexpr (ArchTag::kMinComputeCapability >= 120) { TLLM_THROW( @@ -123,10 +138,13 @@ void tma_warp_specialized_generic_moe_gemm_kernelLauncher( } #endif else { - return DispatchToTmaWSFunction::op(tma_ws_input, num_experts, multi_processor_count, - stream, kernel_occupancy, workspace_size); + return DispatchToTmaWSFunction::op(tma_ws_input, num_experts, + multi_processor_count, stream, + kernel_occupancy, workspace_size, + dynamic_cluster_shape, + fallback_cluster_shape); } } @@ -164,482 +182,553 @@ using SafeBF16 = __nv_bfloat16; using SafeBF16 = void; #endif +using namespace cutlass::epilogue; + // TODO Revert this back to a template instantiation once compiler bug is resolved -#define INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM(ArchTag_, DataType_, WeightType_, OutputType_, \ - EpilogueTag_, FUSION_, CTA_M_, CTA_N_, CTA_K_, \ - CGA_M_, CGA_N_, CGA_K_, MXFPX_, BIAS_) \ - static void \ - tma_warp_specialized_generic_moe_gemm_kernelLauncher_##ArchTag_##_##DataType_##_##WeightType_##_##OutputType_##_##EpilogueTag_##_##FUSION_##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##_##MXFPX_##_##BIAS_( \ - TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, \ - int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, \ - size_t* workspace_size) { \ - constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \ - /* constexpr static bool BIAS = BIAS_; */ /* Always false */ \ - using ArchTag = cutlass::arch::ArchTag_; \ - using T = DataType_; \ - using WeightType = WeightType_; \ - using OutputType = OutputType_; \ - using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \ - using TileShape = cute::Shape, cute::Int, cute::Int>; \ - using ClusterShape = cute::Shape, cute::Int, cute::Int>; \ - constexpr static bool IsMXFPX = MXFPX_; \ - \ - if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && \ - ArchTag::kMinComputeCapability >= 90 && 
ArchTag::kMinComputeCapability < 100) { \ - TLLM_THROW( \ - "Please recompile with support for hopper by passing 90-real as an arch to " \ - "build_wheel.py."); \ - } else if constexpr (!COMPILE_BLACKWELL_TMA_GROUPED_GEMMS_ENABLED && \ - ArchTag::kMinComputeCapability >= 100 && \ - ArchTag::kMinComputeCapability < 120) { \ - TLLM_THROW( \ - "Please recompile with support for blackwell by passing 100-real as an arch to " \ - "build_wheel.py."); \ - } else if constexpr (!COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS_ENABLED && \ - ArchTag::kMinComputeCapability >= 120) { \ - TLLM_THROW( \ - "Please recompile with support for blackwell by passing 120-real as an arch to " \ - "build_wheel.py."); \ - } else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v< \ - ArchTag, TileShape, ClusterShape, T>) { \ - using namespace cute; \ - /* Helper class for defining all the cutlass types \ - // template \ - // struct TmaWarpSpecializedGroupedGemmInfo \ - { */ \ - using Arch = ArchTag; \ - constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \ - constexpr static bool IsSM120 = \ - Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \ - constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same::value && \ - cutlass::platform::is_same::value; \ - constexpr static bool IsFP4 = cutlass::platform::is_same::value; \ - static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by SM100"); \ - \ - constexpr static bool IsFP8 = cutlass::platform::is_same::value; \ - \ - /* TODO Update once mixed input support is added */ \ - static_assert(cutlass::platform::is_same::value || IsWFP4AFP8, \ - "TMA warp specialized MOE implementation does not support mixed input types"); \ - \ - constexpr static bool IsBlockScaled = IsFP4 || IsWFP4AFP8; \ - static_assert(!IsBlockScaled || IsBlackwell, "Block scaled is only implemented for SM100"); \ - \ - static_assert(cutlass::platform::is_same::value || \ - cutlass::platform::is_same::value || \ - cutlass::platform::is_same::value || IsFP8 || IsFP4, \ - "Specialized for bfloat16, half, float, fp8, fp4"); \ - \ - /* The cutlass type for the input elements. 
This is needed to convert to cutlass::half_t if \ - * necessary.*/ \ - using ElementType = typename TllmToCutlassTypeAdapter::type; \ - \ - /* TODO The below never trigger, and are incorrect for int8 types anyway \ - // using CutlassWeightTypeMaybeUint4 = typename \ - TllmToCutlassTypeAdapter::type; \ - // // For legacy reasons we convert unsigned 8-bit to signed \ - // using CutlassWeightTypeMaybeUint8 \ - // = std::conditional_t, cutlass::int4b_t, \ - // CutlassWeightTypeMaybeUint4>; \ - // using CutlassWeightType \ - // = std::conditional_t, int8_t, \ - // CutlassWeightTypeMaybeUint8>; */ \ - using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; \ - \ - using ElementA = ElementType; \ - using ElementB = CutlassWeightType; \ - \ - using ElementD = typename TllmToCutlassTypeAdapter< \ - TmaWarpSpecializedGroupedGemmInput::OutputTypeAdaptor_t>::type; \ - using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; \ - \ - /* using ElementC = std::conditional_t; */ \ - /* using ElementCSafe = std::conditional_t; */ \ - using ElementC = void; \ - using ElementCSafe = ElementD; \ - \ - using ElementAccumulator = float; \ - \ - using ElementBias = ElementFinalOutput; \ - using ElementRouterScales = float; \ - \ - using ElementSF = std::conditional_t< \ - IsMXFPX, cutlass::float_ue8m0_t, \ - cutlass::float_ue4m3_t>; /*TmaWarpSpecializedGroupedGemmInput::ElementSF;*/ \ - using ElementABlockScaled = std::conditional_t, \ - cute::tuple>; \ - using ElementBBlockScaled = std::conditional_t, \ - cute::tuple>; \ - \ - /* A matrix configuration - this is transposed and swapped with B */ \ - using LayoutA = TmaWarpSpecializedGroupedGemmInput::LayoutA; \ - constexpr static int AlignmentA = \ - 128 / \ - cutlass::sizeof_bits::value; /* Memory access granularity/alignment of A \ - matrix in units of elements (up to 16 bytes) */ \ - /* B matrix configuration - this is transposed and swapped with A */ \ - using LayoutB = \ - TmaWarpSpecializedGroupedGemmInput::LayoutB; /* Layout type for B matrix operand */ \ - constexpr static int AlignmentB = \ - IsWFP4AFP8 \ - ? 
128 \ - : (128 / \ - cutlass::sizeof_bits::value); /* Memory access granularity/alignment of \ - B matrix in units \ - // of elements (up to 16 bytes)*/ \ - \ - /* C matrix configuration */ \ - using LayoutC = \ - TmaWarpSpecializedGroupedGemmInput::LayoutC; /* Layout type for C matrix operand */ \ - using StrideC = TmaWarpSpecializedGroupedGemmInput::StrideC; \ - /* Note we use ElementType here deliberately, so we don't break when BIAS is disabled */ \ - constexpr static int AlignmentC = \ - 128 / cutlass::sizeof_bits::value; /* Memory access granularity/alignment \ - of C matrix in \ - // units of elements (up to 16 bytes)*/ \ - \ - /* D matrix configuration */ \ - using LayoutD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::LayoutD; \ - using StrideD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD; \ - constexpr static int AlignmentD = \ - 128 / cutlass::sizeof_bits::value; /* Memory access granularity/alignment of D \ - matrix \ - // in units of elements (up to 16 bytes) */ \ - \ - static_assert( \ - cutlass::platform::is_same::value, \ - "TMA Warp Specialized Grouped GEMM specialisation doesn't support fused activation"); \ - \ - using EpilogueOp = \ - cutlass::epilogue::fusion::LinearCombination; \ - \ - /* TODO Add mode for fused activation once CUTLASS adds support \ - // using EpilogueSchedule = cutlass::platform::conditional_t< \ - // cutlass::platform::is_same::value, \ - // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \ - // cutlass::epilogue::?????????????????? /// <<<<<< what supports \ - activations \ - // >;*/ \ - using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; \ - \ - constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \ - using EpilogueScheduleSM100 = \ - std::conditional_t; \ - using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \ - using EpilogueScheduleBW = \ - std ::conditional_t; \ - using EpilogueSchedule = \ - std::conditional_t; \ - \ - using EpilogueTileShapeSm90 = TileShape; \ - using AtomClusterDiv = std::conditional_t; \ - using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape{})); \ - using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \ - using EpilogueTileShape = \ - std::conditional_t; \ - using EpilogueElementC = std::conditional_t; \ - using EpilogueTensorOp = std::conditional_t; \ - using EpilogueSubTile = std::conditional_t< \ - Arch::kMinComputeCapability == 100 && IsFP4 && CTA_N_ == 256, /* SM100 Exactly */ \ - cute::Shape, cutlass::epilogue::collective::EpilogueTileAuto>; \ - /* Epilogue For Default Finalize */ \ - using CollectiveEpilogueDefault = typename cutlass::epilogue::collective:: \ - CollectiveBuilder< /**/ \ - Arch, EpilogueTensorOp, /**/ \ - EpilogueTileShape, ClusterShape, /**/ \ - EpilogueSubTile, /**/ \ - ElementAccumulator, ElementAccumulator, /**/ \ - EpilogueElementC, LayoutC*, AlignmentC, /**/ \ - ElementD, LayoutD*, AlignmentD, /**/ \ - EpilogueSchedule>::CollectiveOp; \ - \ - /* Epilogue For Fused Finalize */ \ - using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective:: \ - EpilogueMoeFusedFinalizeBuilder< /**/ \ - Arch, EpilogueTileShape, /**/ \ - ElementCSafe, StrideC*, /**/ \ - ElementFinalOutput, \ - TmaWarpSpecializedGroupedGemmInput:: \ - FusedFinalizeEpilogue::StrideFinalOutput, /**/ \ - ElementAccumulator, /**/ \ - ElementAccumulator, /**/ \ - ElementBias, \ - TmaWarpSpecializedGroupedGemmInput:: \ - FusedFinalizeEpilogue::StrideBias, /**/ \ - 
ElementRouterScales, \ - TmaWarpSpecializedGroupedGemmInput:: \ - FusedFinalizeEpilogue::StrideRouterScales /**/ \ - >::CollectiveOp; \ - \ - using CollectiveEpilogue = \ - std::conditional_t; \ - \ - using StageCountAutoCarveout = \ - cutlass::gemm::collective::StageCountAutoCarveout( \ - sizeof(typename CollectiveEpilogue::SharedStorage))>; \ - \ - using KernelScheduleSM90 = std::conditional_t< \ - IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, \ - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>; \ - \ - using KernelSchedule2SmSm100BlockScaled = \ - std::conditional_t; \ - using KernelSchedule1SmSm100BlockScaled = \ - std::conditional_t; \ - \ - /* TRT-LLM uses vector size 16 for block scaled */ \ - using KernelScheduleSM100 = std::conditional_t< \ - Is2SM, \ - std::conditional_t, \ - std::conditional_t>; \ - using KernelScheduleSM120 = cutlass ::gemm ::collective::KernelScheduleAuto; \ - using KernelScheduleBW = \ - std::conditional_t; \ - \ - using KernelSchedule = \ - std::conditional_t; \ - \ - using TensorOp = std::conditional_t; \ - \ - using MainloopElementA = \ - std::conditional_t; \ - using MainloopElementB = \ - std::conditional_t; \ - \ - using MainloopTileShapeSm90 = TileShape; \ - using MainloopTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \ - using MainloopTileShape = \ - std::conditional_t; \ - \ - using CollectiveMainloop = typename cutlass::gemm::collective:: \ - CollectiveBuilder< /**/ \ - Arch, TensorOp, /**/ \ - MainloopElementB, LayoutB*, AlignmentB, /* A & B swapped here */ \ - MainloopElementA, LayoutA*, AlignmentA, /**/ \ - ElementAccumulator, /**/ \ - MainloopTileShape, ClusterShape, /**/ \ - StageCountAutoCarveout, KernelSchedule>::CollectiveOp; \ - \ - using GemmKernel = \ - cutlass::gemm::kernel::GemmUniversal; \ - \ - using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; \ - /*}; \ - \ \ - // using namespace cute; \ - // using GemmInfo = TmaWarpSpecializedGroupedGemmInfo;; \ - // \ - // using ElementAccumulator = typename GemmInfo::ElementAccumulator; \ - // using ElementA = typename GemmInfo::ElementA; \ - // using ElementB = typename GemmInfo::ElementB; \ - // using ElementC = typename GemmInfo::ElementC; \ - // using ElementCSafe = typename GemmInfo::ElementCSafe; \ - // using ElementD = typename GemmInfo::ElementD; \ - // using ElementFinalOutput = typename GemmInfo::ElementFinalOutput; \ - // using ElementBias = typename GemmInfo::ElementBias; \ - // \ - // using CollectiveMainloop = typename GemmInfo::CollectiveMainloop; \ - // using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue; \ - // using GemmKernel = typename GemmInfo::GemmKernel; \ - // using GemmGrouped = typename GemmInfo::GemmGrouped;*/ \ - \ - if (kernel_occupancy != nullptr) { \ - TLLM_THROW("TMA WS kernels do not support calculating occupancy"); \ - return; \ - } \ - \ - cutlass::KernelHardwareInfo hw_info; \ - hw_info.device_id = 0; \ - hw_info.sm_count = multi_processor_count; \ - \ - GemmGrouped gemm; \ - \ - if (workspace_size != nullptr) { \ - /* Make a mock problem shape with just the minimal information actually required to get \ - the workspace \ - // size This makes some assumptions about CUTLASS's implementation which is suboptimal. We \ - have a check \ - // later to catch future cutlass updates causing silent breakages, but that is not fool \ - proof. 
The \ - // alternative is to wait until we have data and then dynamically allocate the workspace*/ \ - typename TmaWarpSpecializedGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, \ - nullptr}; \ - \ - typename GemmKernel::TileScheduler::Arguments scheduler_args{ \ - 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ - const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ - shape_info, \ - {}, \ - {}, \ - hw_info, \ - scheduler_args}; \ - *workspace_size = gemm.get_workspace_size(args); \ - return; \ - } \ - \ - using MainloopArguments = typename CollectiveMainloop::Arguments; \ - TLLM_CHECK(tma_ws_input.stride_a); \ - TLLM_CHECK(tma_ws_input.stride_b); \ - TLLM_CHECK(tma_ws_input.ptr_a); \ - TLLM_CHECK(tma_ws_input.ptr_b); \ - \ - auto make_mainloop_params = [&]() -> MainloopArguments { \ - if constexpr (IsBlockScaled) { \ - return construct_if_true( \ - reinterpret_cast(tma_ws_input.ptr_b), tma_ws_input.stride_b, \ - reinterpret_cast(tma_ws_input.ptr_a), tma_ws_input.stride_a, \ - reinterpret_cast(tma_ws_input.fpX_block_scaling_factors_B), \ - reinterpret_cast())>( \ - tma_ws_input.fpX_block_scaling_factors_stride_B), \ - reinterpret_cast(tma_ws_input.fpX_block_scaling_factors_A), \ - reinterpret_cast())>( \ - tma_ws_input.fpX_block_scaling_factors_stride_A)); \ - } else { \ - return construct_if_true( \ - reinterpret_cast(tma_ws_input.ptr_b), tma_ws_input.stride_b, \ - reinterpret_cast(tma_ws_input.ptr_a), tma_ws_input.stride_a); \ - } \ - }; \ - \ - auto const mainloop_params = make_mainloop_params(); \ - \ - using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ - using EpilogueScalars = decltype(EpilogueArguments{}.thread); \ - auto make_epilogue_scalars = [&]() { \ - if constexpr (IsBlackwell) { \ - return construct_if_true( \ - ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, \ - nullptr, tma_ws_input.alpha_scale_ptr_array, nullptr, \ - cute::Shape<_0, _0, int64_t>{ \ - cute::_0{}, cute::_0{}, \ - (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \ - cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \ - } else if (tma_ws_input.alpha_scale_ptr_array) { \ - return construct_if_true( \ - tma_ws_input.alpha_scale_ptr_array); \ - } else { \ - return construct_if_true( \ - ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? 
ElementAccumulator(1.f) : ElementAccumulator(0.f)); \ - } \ - }; \ - auto epilogue_scalars = make_epilogue_scalars(); \ - /* TODO ptr_c casts to ElementCSafe** because there is a workaround in CUTLASS */ \ - auto make_epi_args = [&]() { \ - static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ - "Unimplemented fusion provided to TMA WS MoE gemm launcher"); \ - \ - if constexpr (FUSION == EpilogueFusion::NONE) { \ - auto epi_params = tma_ws_input.default_epilogue; \ - return construct_if_true < FUSION == EpilogueFusion::NONE, \ - EpilogueArguments > (epilogue_scalars, nullptr, tma_ws_input.stride_c, \ - reinterpret_cast(epi_params.ptr_d), \ - epi_params.stride_d); \ - } else if constexpr (FUSION == EpilogueFusion::FINALIZE) { \ - /* Parameters for fused finalize */ \ - auto epi_params = tma_ws_input.fused_finalize_epilogue; \ - return construct_if_true < FUSION == EpilogueFusion::FINALIZE, \ - EpilogueArguments > \ - (epilogue_scalars, /* Parameters to underlying epilogue */ \ - nullptr, tma_ws_input.stride_c, /* C params */ \ - reinterpret_cast(epi_params.ptr_final_output), \ - epi_params.stride_final_output, /* D (output) params */ \ - reinterpret_cast(epi_params.ptr_bias), \ - epi_params.stride_bias, /* Bias params */ \ - epi_params.ptr_router_scales, \ - epi_params.stride_router_scales, /* Router scales */ \ - epi_params.ptr_expert_first_token_offset, /* Offset of this expert's token \ - in the router scales */ \ - epi_params \ - .ptr_source_token_index, /* Index of the source token to sum into */ \ - epi_params \ - .num_rows_in_final_output /* Number of tokens in the output buffer */ \ - ); \ - } \ - }; \ - EpilogueArguments const epilogue_params = make_epi_args(); \ - /* EpilogueArguments const epilogue_params = make_epi_args( \ - // tma_ws_input, epilogue_scalars \ - // );*/ \ - \ - typename GemmKernel::TileScheduler::Arguments scheduler_args{ \ - 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ - \ - const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ - tma_ws_input.shape_info, \ - mainloop_params, \ - epilogue_params, \ - hw_info, \ - scheduler_args}; \ - \ - size_t calculated_ws_size = gemm.get_workspace_size(args); \ - TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \ - "Workspace is size %zu but only %zu were allocated", \ - calculated_ws_size, tma_ws_input.gemm_workspace_size); \ - \ - auto can_implement = gemm.can_implement(args); \ - TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, \ - "Grouped GEMM kernel will fail for params. Error: " + \ - std::string(cutlass::cutlassGetStatusString(can_implement))); \ - \ - auto init_status = gemm.initialize(args, tma_ws_input.gemm_workspace); \ - TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, \ - "Failed to initialize cutlass TMA WS grouped gemm. Error: " + \ - std::string(cutlass::cutlassGetStatusString(init_status))); \ - auto run_status = gemm.run(stream, nullptr, tma_ws_input.enable_pdl); \ - TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, \ - "Failed to run cutlass TMA WS grouped gemm. 
Error: " + \ - std::string(cutlass::cutlassGetStatusString(run_status))); \ - sync_check_cuda_error(stream); \ - } else { \ - TLLM_THROW("Configuration was disabled by FAST_BUILD"); \ - } \ - \ - return; \ - } \ - \ - template <> \ - struct DispatchToTmaWSFunction< \ - cutlass::arch::ArchTag_, DataType_, WeightType_, OutputType_, \ - tensorrt_llm::cutlass_extensions::EpilogueTag_, EpilogueFusion::FUSION_, \ - cute::Shape, cute::Int, cute::Int>, \ - cute::Shape, cute::Int, cute::Int>, MXFPX_, BIAS_> { \ - constexpr static auto* op = \ - &tma_warp_specialized_generic_moe_gemm_kernelLauncher_##ArchTag_##_##DataType_##_##WeightType_##_##OutputType_##_##EpilogueTag_##_##FUSION_##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##_##MXFPX_##_##BIAS_; \ - }; \ - template void tma_warp_specialized_generic_moe_gemm_kernelLauncher< \ - cutlass::arch::ArchTag_, DataType_, WeightType_, OutputType_, \ - tensorrt_llm::cutlass_extensions::EpilogueTag_, EpilogueFusion::FUSION_, \ - cute::Shape, cute::Int, cute::Int>, \ - cute::Shape, cute::Int, cute::Int>, MXFPX_, BIAS_>( \ - TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, \ - int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, \ - size_t* workspace_size); +#define INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM( \ + ArchTag_, DataType_, WeightType_, OutputType_, EpilogueSchedule_, EpilogueTag_, FUSION_, \ + CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, MXFPX_, DYNAMIC_CGA_, BIAS_, SWAP_AB_) \ + static void \ + tma_warp_specialized_generic_moe_gemm_kernelLauncher_##ArchTag_##_##DataType_##_##WeightType_##_##OutputType_##_##EpilogueSchedule_##_##EpilogueTag_##_##FUSION_##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##_##MXFPX_##_##DYNAMIC_CGA_##_##BIAS_##_##SWAP_AB_( \ + TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, \ + int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, \ + size_t* workspace_size, cute::Shape dynamic_cluster_shape, \ + cute::Shape fallback_cluster_shape) { \ + using ArchTag = cutlass::arch::ArchTag_; \ + constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \ + constexpr static bool IsMXFPX = MXFPX_; \ + constexpr static bool DYNAMIC_CGA = DYNAMIC_CGA_; \ + constexpr static bool SwapAB = SWAP_AB_; \ + constexpr bool IsBlackwell = ArchTag::kMinComputeCapability >= 100; \ + constexpr static bool IsSM10x = \ + ArchTag::kMinComputeCapability >= 100 && ArchTag::kMinComputeCapability < 120; \ + constexpr static bool IsSM103 = ArchTag::kMinComputeCapability == 103; \ + constexpr bool IsSM120 = \ + ArchTag::kMinComputeCapability == 120 || ArchTag::kMinComputeCapability == 121; \ + /* constexpr static bool BIAS = BIAS_; */ /* Always false */ \ + using T = DataType_; \ + using WeightType = WeightType_; \ + using OutputType = OutputType_; \ + using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \ + using InputClusterShape = \ + cute::Shape, cute::Int, cute::Int>; \ + constexpr static bool Is2SM = IsSM10x && cute::size<0>(InputClusterShape{}) == 2; \ + using ClusterShape = std::conditional_t, \ + InputClusterShape>; \ + using MmaTileShape = cute::Shape, cute::Int, \ + cute::Int>; \ + using InputEpilogueSchedule = EpilogueSchedule_; \ + if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && \ + ArchTag::kMinComputeCapability >= 90 && ArchTag::kMinComputeCapability < 100) { \ + TLLM_THROW( \ + "Please recompile with support for hopper by passing 90-real as an arch to " \ + "build_wheel.py."); \ + } else if 
constexpr (!COMPILE_BLACKWELL_TMA_GROUPED_GEMMS_ENABLED && \ + ArchTag::kMinComputeCapability >= 100 && \ + ArchTag::kMinComputeCapability < 120) { \ + TLLM_THROW( \ + "Please recompile with support for blackwell by passing 100-real as an arch to " \ + "build_wheel.py."); \ + } else if constexpr (!COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS_ENABLED && \ + ArchTag::kMinComputeCapability >= 120) { \ + TLLM_THROW( \ + "Please recompile with support for blackwell by passing 120-real as an arch to " \ + "build_wheel.py."); \ + } else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v< \ + ArchTag, MmaTileShape, ClusterShape, DYNAMIC_CGA, T>) { \ + TLLM_CHECK_WITH_INFO(SwapAB == tma_ws_input.swap_ab, "SwapAB must match runtime swap_ab"); \ + using namespace cute; \ + /* Helper class for defining all the cutlass types \ + // template \ + // struct TmaWarpSpecializedGroupedGemmInfo \ + { */ \ + constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same::value && \ + cutlass::platform::is_same::value; \ + constexpr static bool IsFP4 = cutlass::platform::is_same::value; \ + static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by SM100"); \ + \ + constexpr static bool IsFP8 = cutlass::platform::is_same::value; \ + \ + /* TODO Update once mixed input support is added */ \ + static_assert(cutlass::platform::is_same::value || IsWFP4AFP8, \ + "TMA warp specialized MOE implementation does not support mixed input types"); \ + \ + constexpr static bool IsBlockScaled = IsFP4 || IsWFP4AFP8; \ + static_assert(!IsBlockScaled || IsBlackwell, "Block scaled is only implemented for SM100"); \ + \ + static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ + "Unimplemented fusion provided to TMA WS MoE gemm launcher"); \ + constexpr static bool IsFinalizeFusion = FUSION == EpilogueFusion::FINALIZE; \ + constexpr bool IsTmaSM10xEpilogue = \ + std::is_same_v; \ + \ + static_assert(cutlass::platform::is_same::value || \ + cutlass::platform::is_same::value || \ + cutlass::platform::is_same::value || IsFP8 || IsFP4, \ + "Specialized for bfloat16, half, float, fp8, fp4"); \ + \ + /* The cutlass type for the input elements. 
This is needed to convert to cutlass::half_t if \ + * necessary.*/ \ + using ElementType = typename TllmToCutlassTypeAdapter::type; \ + \ + /* TODO The below never trigger, and are incorrect for int8 types anyway \ + // using CutlassWeightTypeMaybeUint4 = typename \ + TllmToCutlassTypeAdapter::type; \ + // // For legacy reasons we convert unsigned 8-bit to signed \ + // using CutlassWeightTypeMaybeUint8 \ + // = std::conditional_t, cutlass::int4b_t, \ + // CutlassWeightTypeMaybeUint4>; \ + // using CutlassWeightType \ + // = std::conditional_t, int8_t, \ + // CutlassWeightTypeMaybeUint8>; */ \ + using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; \ + \ + using ElementAct = ElementType; \ + using ElementWeight = CutlassWeightType; \ + \ + using ElementD = typename TllmToCutlassTypeAdapter< \ + TmaWarpSpecializedGroupedGemmInput::OutputTypeAdaptor_t>::type; \ + using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; \ + \ + /* using ElementC = std::conditional_t; */ \ + /* using ElementCSafe = std::conditional_t; */ \ + using ElementC = void; \ + using ElementCSafe = ElementD; \ + \ + using ElementAccumulator = float; \ + \ + using ElementBias = ElementFinalOutput; \ + using ElementRouterScales = float; \ + \ + using ElementSF = std::conditional_t< \ + IsMXFPX, cutlass::float_ue8m0_t, \ + cutlass::float_ue4m3_t>; /*TmaWarpSpecializedGroupedGemmInput::ElementSF;*/ \ + using ElementActBlockScaled = \ + std::conditional_t, \ + cutlass::nv_float4_t>, \ + cute::tuple>; \ + using ElementWeightBlockScaled = \ + std::conditional_t, \ + cutlass::nv_float4_t>, \ + cute::tuple>; \ + \ + /* Activation matrix alignment */ \ + constexpr static int AlignmentAct = \ + 128 / \ + cutlass::sizeof_bits::value; /* Memory access granularity/alignment of A \ + matrix in units of elements (up to 16 bytes) */ \ + /* Weight matrix alignment */ \ + constexpr static int AlignmentWeight = \ + IsWFP4AFP8 \ + ? 128 \ + : (128 / \ + cutlass::sizeof_bits::value); /* Memory access \ + granularity/alignment of B matrix in units \ + // of elements (up to 16 bytes)*/ \ + \ + /* C matrix configuration */ \ + /* Note we use ElementType here deliberately, so we don't break when BIAS is disabled */ \ + constexpr static int AlignmentC = \ + 128 / cutlass::sizeof_bits::value; /* Memory access granularity/alignment \ + of C matrix in \ + // units of elements (up to 16 bytes)*/ \ + \ + /* D matrix configuration */ \ + constexpr static int AlignmentDBits = \ + (IsSM10x && !IsTmaSM10xEpilogue) \ + ? 256 \ + : 128; /* For NoSmem epilogue schedule, we need to align to 256 bits */ \ + constexpr static int AlignmentD = \ + AlignmentDBits / cutlass::sizeof_bits::value; /* Memory access \ + granularity/alignment of D matrix \ + // in units of elements (up to 16 bytes) */ \ + \ + static_assert( \ + cutlass::platform::is_same::value, \ + "TMA Warp Specialized Grouped GEMM specialisation doesn't support fused activation"); \ + \ + using EpilogueOp = \ + cutlass::epilogue::fusion::LinearCombination; \ + \ + /* TODO Add mode for fused activation once CUTLASS adds support \ + // using EpilogueSchedule = cutlass::platform::conditional_t< \ + // cutlass::platform::is_same::value, \ + // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \ + // cutlass::epilogue::?????????????????? 
/// <<<<<< what supports \ + activations \ + // >;*/ \ + using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \ + \ + using EpilogueScheduleSM10x = std::conditional_t< \ + IsTmaSM10xEpilogue, \ + std::conditional_t, \ + std::conditional_t>; \ + using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \ + using EpilogueSchedule = std::conditional_t< \ + IsSM10x, EpilogueScheduleSM10x, \ + std::conditional_t>; \ + using EpilogueElementC = std::conditional_t; \ + using EpilogueTensorOp = std::conditional_t; \ + using EpilogueScheduleSM10xFinalize = std::conditional_t< \ + !IsFinalizeFusion && IsSM10x, \ + std::conditional_t, \ + EpilogueSchedule>; /* This still needs to be valid when finalize fusion is disabled */ \ + \ + using EpilogueSubTile = std::conditional_t< \ + ArchTag::kMinComputeCapability == 100 && IsFP4 && CTA_N_ == 256, /* SM100 Exactly */ \ + cute::Shape, cutlass::epilogue::collective::EpilogueTileAuto>; \ + \ + using LayoutC = std::conditional_t; \ + using StrideC = std::conditional_t; \ + using LayoutD = std::conditional_t; \ + using StrideD = std::conditional_t; \ + \ + /* Epilogue For Default Finalize */ \ + using CollectiveEpilogueDefault = typename cutlass::epilogue::collective:: \ + CollectiveBuilder< /**/ \ + ArchTag, EpilogueTensorOp, /**/ \ + MmaTileShape, ClusterShape, /**/ \ + EpilogueSubTile, /**/ \ + ElementAccumulator, ElementAccumulator, /**/ \ + EpilogueElementC, LayoutC*, AlignmentC, /**/ \ + ElementD, LayoutD*, AlignmentD, /**/ \ + EpilogueSchedule>::CollectiveOp; \ + \ + /* Epilogue For Fused Finalize */ \ + using EpilogueFusionOp = std::conditional_t< \ + SwapAB, \ + cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter< \ + LayoutD, ElementFinalOutput, ElementAccumulator, ElementBias, ElementRouterScales>, \ + cutlass::epilogue::fusion::ScaledAccPerColBiasPerRowScaleScatter< \ + LayoutD, ElementFinalOutput, ElementAccumulator, ElementBias, ElementRouterScales>>; \ + using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective:: \ + CollectiveBuilder< /**/ \ + ArchTag, EpilogueTensorOp, /**/ \ + MmaTileShape, InputClusterShape, /**/ \ + EpilogueSubTile, /**/ \ + ElementAccumulator, ElementAccumulator, /**/ \ + EpilogueElementC, LayoutC*, AlignmentC, /**/ \ + void, LayoutD*, AlignmentD, /**/ \ + EpilogueScheduleSM10xFinalize, /**/ \ + EpilogueFusionOp /**/ \ + >::CollectiveOp; \ + \ + using CollectiveEpilogue = std::conditional_t; \ + \ + using StageCountAutoCarveout = \ + cutlass::gemm::collective::StageCountAutoCarveout( \ + sizeof(typename CollectiveEpilogue::SharedStorage))>; \ + \ + using KernelScheduleSM90 = std::conditional_t< \ + IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, \ + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>; \ + \ + using KernelSchedule2SmSm100BlockScaled = \ + std::conditional_t; \ + using KernelSchedule1SmSm100BlockScaled = \ + std::conditional_t; \ + \ + /* TRT-LLM uses vector size 16 for block scaled */ \ + using KernelScheduleSM100 = std::conditional_t< \ + Is2SM, \ + std::conditional_t, \ + std::conditional_t>; \ + using KernelScheduleSM103 = std::conditional_t< \ + Is2SM, \ + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103, \ + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103>; \ + using KernelScheduleSM10x = \ + std::conditional_t; \ + using KernelScheduleSM120 = cutlass ::gemm ::collective::KernelScheduleAuto; \ + using KernelScheduleBW = \ + 
std::conditional_t; \ + \ + using KernelSchedule = \ + std::conditional_t; \ + \ + using TensorOp = std::conditional_t; \ + \ + using MainloopElementAct = \ + std::conditional_t; \ + using MainloopElementWeight = std::conditional_t; \ + using SwappedMainloopElementA = \ + std::conditional_t; \ + using SwappedMainloopElementB = \ + std::conditional_t; \ + constexpr auto SwappedAlignmentA = SwapAB ? AlignmentWeight : AlignmentAct; \ + constexpr auto SwappedAlignmentB = SwapAB ? AlignmentAct : AlignmentWeight; \ + using LayoutA = TmaWarpSpecializedGroupedGemmInput::LayoutA; \ + using LayoutB = TmaWarpSpecializedGroupedGemmInput::LayoutB; \ + using StrideA = typename TmaWarpSpecializedGroupedGemmInput::StrideA; \ + using StrideB = typename TmaWarpSpecializedGroupedGemmInput::StrideB; \ + using CollectiveMainloop = typename cutlass::gemm::collective:: \ + CollectiveBuilder< /**/ \ + ArchTag, TensorOp, /**/ \ + SwappedMainloopElementA, LayoutA*, SwappedAlignmentA, /**/ \ + SwappedMainloopElementB, LayoutB*, SwappedAlignmentB, /**/ \ + ElementAccumulator, /**/ \ + MmaTileShape, ClusterShape, /**/ \ + StageCountAutoCarveout, KernelSchedule>::CollectiveOp; \ + \ + using GemmKernel = \ + cutlass::gemm::kernel::GemmUniversal; \ + \ + using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; \ + \ + if (kernel_occupancy != nullptr) { \ + TLLM_THROW("TMA WS kernels do not support calculating occupancy"); \ + return; \ + } \ + \ + cutlass::KernelHardwareInfo hw_info; \ + hw_info.device_id = 0; \ + hw_info.sm_count = multi_processor_count; \ + \ + if constexpr (DYNAMIC_CGA) { \ + TLLM_CHECK(cute::size<0>(dynamic_cluster_shape) >= 1); \ + TLLM_CHECK(cute::size<1>(dynamic_cluster_shape) >= 1); \ + TLLM_CHECK(cute::size<0>(fallback_cluster_shape) >= 1); \ + TLLM_CHECK(cute::size<1>(fallback_cluster_shape) >= 1); \ + TLLM_CHECK_WITH_INFO( \ + cute::size<0>(dynamic_cluster_shape) % cute::size<0>(fallback_cluster_shape) == 0, \ + "Dynamic cluster shape (%dx%d) must be divisible by cluster shape (%dx%d)", \ + (int)cute::size<0>(dynamic_cluster_shape), (int)cute::size<1>(dynamic_cluster_shape), \ + (int)cute::size<0>(fallback_cluster_shape), \ + (int)cute::size<1>(fallback_cluster_shape)); \ + TLLM_CHECK_WITH_INFO( \ + cute::size<0>(fallback_cluster_shape) % cute::size<0>(InputClusterShape{}) == 0, \ + "Fallback cluster shape (%dx%d) must be divisible by MMA cluster shape (%dx%d)", \ + (int)cute::size<0>(fallback_cluster_shape), \ + (int)cute::size<1>(fallback_cluster_shape), (int)cute::size<0>(InputClusterShape{}), \ + (int)cute::size<1>(InputClusterShape{})); \ + hw_info.cluster_shape = \ + dim3(cute::size<0>(dynamic_cluster_shape), cute::size<1>(dynamic_cluster_shape), 1); \ + hw_info.cluster_shape_fallback = \ + dim3(cute::size<0>(fallback_cluster_shape), cute::size<1>(fallback_cluster_shape), 1); \ + } \ + GemmGrouped gemm; \ + \ + if (workspace_size != nullptr) { \ + /* Make a mock problem shape with just the minimal information actually required to get \ + the workspace \ + // size This makes some assumptions about CUTLASS's implementation which is suboptimal. We \ + have a check \ + // later to catch future cutlass updates causing silent breakages, but that is not fool \ + proof. 
The \ + // alternative is to wait until we have data and then dynamically allocate the workspace*/ \ + typename TmaWarpSpecializedGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, \ + nullptr}; \ + \ + typename GemmKernel::TileScheduler::Arguments scheduler_args{ \ + 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ + const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ + shape_info, \ + {}, \ + {}, \ + hw_info, \ + scheduler_args}; \ + *workspace_size = gemm.get_workspace_size(args); \ + return; \ + } \ + \ + using MainloopArguments = typename CollectiveMainloop::Arguments; \ + TLLM_CHECK(tma_ws_input.stride_act); \ + TLLM_CHECK(tma_ws_input.stride_weight); \ + TLLM_CHECK(tma_ws_input.ptr_act); \ + TLLM_CHECK(tma_ws_input.ptr_weight); \ + \ + MainloopArguments const mainloop_args = [&] { \ + if constexpr (IsBlockScaled) { \ + if constexpr (SwapAB) { \ + return construct_if_true<(IsBlockScaled && SwapAB), MainloopArguments>( \ + reinterpret_cast(tma_ws_input.ptr_weight), \ + reinterpret_cast(tma_ws_input.stride_weight), \ + reinterpret_cast(tma_ws_input.ptr_act), \ + reinterpret_cast(tma_ws_input.stride_act), \ + reinterpret_cast( \ + tma_ws_input.fpX_block_scaling_factors_weight), \ + reinterpret_cast())>( \ + tma_ws_input.fpX_block_scaling_factors_stride_weight), \ + reinterpret_cast(tma_ws_input.fpX_block_scaling_factors_act), \ + reinterpret_cast())>( \ + tma_ws_input.fpX_block_scaling_factors_stride_act)); \ + } else { \ + return construct_if_true<(IsBlockScaled && !SwapAB), MainloopArguments>( \ + reinterpret_cast(tma_ws_input.ptr_act), \ + reinterpret_cast(tma_ws_input.stride_act), \ + reinterpret_cast(tma_ws_input.ptr_weight), \ + reinterpret_cast(tma_ws_input.stride_weight), \ + reinterpret_cast(tma_ws_input.fpX_block_scaling_factors_act), \ + reinterpret_cast())>( \ + tma_ws_input.fpX_block_scaling_factors_stride_act), \ + reinterpret_cast( \ + tma_ws_input.fpX_block_scaling_factors_weight), \ + reinterpret_cast())>( \ + tma_ws_input.fpX_block_scaling_factors_stride_weight)); \ + } \ + } else { \ + if constexpr (SwapAB) { \ + return construct_if_true<(!IsBlockScaled && SwapAB), MainloopArguments>( \ + reinterpret_cast(tma_ws_input.ptr_weight), \ + reinterpret_cast(tma_ws_input.stride_weight), \ + reinterpret_cast(tma_ws_input.ptr_act), \ + reinterpret_cast(tma_ws_input.stride_act)); \ + } else { \ + return construct_if_true<(!IsBlockScaled && !SwapAB), MainloopArguments>( \ + reinterpret_cast(tma_ws_input.ptr_act), \ + reinterpret_cast(tma_ws_input.stride_act), \ + reinterpret_cast(tma_ws_input.ptr_weight), \ + reinterpret_cast(tma_ws_input.stride_weight)); \ + } \ + } \ + }(); \ + using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ + using EpilogueScalars = decltype(EpilogueArguments{}.thread); \ + EpilogueScalars epilogue_scalars = [&] { \ + constexpr bool IsSimpleAlphaBeta = \ + std::is_constructible_v; \ + if constexpr (IsFinalizeFusion) { \ + auto epi_params = tma_ws_input.fused_finalize_epilogue; \ + if constexpr (SwapAB) { \ + return construct_if_true<(FUSION == EpilogueFusion::FINALIZE && SwapAB), \ + EpilogueScalars>( \ + ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \ + reinterpret_cast(epi_params.ptr_bias), \ + Stride<_1, _0, int64_t>{}, /* bias */ \ + epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */ \ + reinterpret_cast(epi_params.ptr_final_output), \ + 
epi_params.stride_final_output_transposed, epi_params.ptr_source_token_index, \ + epi_params.num_rows_in_final_output, epi_params.shape_override, \ + epi_params.use_reduction); \ + } else { \ + return construct_if_true<(FUSION == EpilogueFusion::FINALIZE && !SwapAB), \ + EpilogueScalars>( \ + ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \ + reinterpret_cast(epi_params.ptr_bias), \ + Stride<_0, _1, int64_t>{}, /* bias */ \ + epi_params.ptr_router_scales, Stride<_1, _0, int64_t>{}, /* scale */ \ + reinterpret_cast(epi_params.ptr_final_output), \ + epi_params.stride_final_output, epi_params.ptr_source_token_index, \ + epi_params.num_rows_in_final_output, epi_params.shape_override, \ + epi_params.use_reduction); \ + } \ + } else if constexpr (!IsSimpleAlphaBeta) { \ + return construct_if_true<(!IsSimpleAlphaBeta && !IsFinalizeFusion), EpilogueScalars>( \ + ElementAccumulator(1.f), \ + tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, \ + nullptr, tma_ws_input.alpha_scale_ptr_array, nullptr, \ + cute::Shape<_0, _0, int64_t>{ \ + cute::_0{}, cute::_0{}, \ + (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \ + cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \ + } else if (tma_ws_input.alpha_scale_ptr_array) { \ + return construct_if_true<(IsSimpleAlphaBeta && !IsFinalizeFusion), EpilogueScalars>( \ + tma_ws_input.alpha_scale_ptr_array); \ + } else { \ + return construct_if_true<(IsSimpleAlphaBeta && !IsFinalizeFusion), EpilogueScalars>( \ + ElementAccumulator(1.f), \ + tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); \ + } \ + }(); \ + \ + EpilogueArguments epilogue_args = [&] { \ + if constexpr (FUSION == EpilogueFusion::FINALIZE) { \ + return construct_if_true < FUSION == EpilogueFusion::FINALIZE, \ + EpilogueArguments > (epilogue_scalars, nullptr, nullptr, nullptr, nullptr); \ + } else { \ + return construct_if_true < FUSION != EpilogueFusion::FINALIZE, \ + EpilogueArguments > (epilogue_scalars, nullptr, nullptr, \ + reinterpret_cast(tma_ws_input.ptr_d), \ + reinterpret_cast(tma_ws_input.stride_d)); \ + } \ + }(); \ + \ + typename GemmKernel::TileScheduler::Arguments scheduler_args{ \ + 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ + \ + const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ + tma_ws_input.shape_info, \ + mainloop_args, \ + epilogue_args, \ + hw_info, \ + scheduler_args}; \ + \ + size_t calculated_ws_size = gemm.get_workspace_size(args); \ + TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \ + "Workspace is size %zu but only %zu were allocated", \ + calculated_ws_size, tma_ws_input.gemm_workspace_size); \ + \ + auto can_implement = gemm.can_implement(args); \ + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, \ + "Grouped GEMM kernel will fail for params. Error: " + \ + std::string(cutlass::cutlassGetStatusString(can_implement))); \ + \ + auto init_status = gemm.initialize(args, tma_ws_input.gemm_workspace); \ + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, \ + "Failed to initialize cutlass TMA WS grouped gemm. Error: " + \ + std::string(cutlass::cutlassGetStatusString(init_status))); \ + auto run_status = gemm.run(stream, nullptr, tma_ws_input.enable_pdl); \ + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, \ + "Failed to run cutlass TMA WS grouped gemm. 
Error: " + \ + std::string(cutlass::cutlassGetStatusString(run_status))); \ + sync_check_cuda_error(stream); \ + } else { \ + TLLM_THROW("Configuration was disabled by FAST_BUILD"); \ + } \ + \ + return; \ + } \ + \ + template <> \ + struct DispatchToTmaWSFunction< \ + cutlass::arch::ArchTag_, DataType_, WeightType_, OutputType_, EpilogueSchedule_, \ + tensorrt_llm::cutlass_extensions::EpilogueTag_, EpilogueFusion::FUSION_, \ + cute::Shape, cute::Int, cute::Int>, \ + cute::Shape, cute::Int, cute::Int>, MXFPX_, DYNAMIC_CGA_, \ + BIAS_, SWAP_AB_> { \ + constexpr static auto* op = &tma_warp_specialized_generic_moe_gemm_kernelLauncher_##ArchTag_##_##DataType_##_##WeightType_##_##OutputType_##_##EpilogueSchedule_##_##EpilogueTag_##_##FUSION_##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##_##MXFPX_##_##DYNAMIC_CGA_##_##BIAS_##_##SWAP_AB_; \ + }; \ + template void tma_warp_specialized_generic_moe_gemm_kernelLauncher< \ + cutlass::arch::ArchTag_, DataType_, WeightType_, OutputType_, EpilogueSchedule_, \ + tensorrt_llm::cutlass_extensions::EpilogueTag_, EpilogueFusion::FUSION_, \ + cute::Shape, cute::Int, cute::Int>, \ + cute::Shape, cute::Int, cute::Int>, MXFPX_, DYNAMIC_CGA_, \ + BIAS_, SWAP_AB_>(TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, \ + int const multi_processor_count, cudaStream_t stream, \ + int* kernel_occupancy, size_t* workspace_size, \ + cute::Shape dynamic_cluster_shape, \ + cute::Shape fallback_cluster_shape); -} // namespace cutlass_kernels +} // namespace cutlass_kernels_oss } // namespace kernels } // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h index 16ebddca32..91d12ef0e7 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h @@ -1,32 +1,39 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ #include +#include "../../include/moe_gemm_kernels.h" #include "cutlass_extensions/gemm_configs.h" #include "cutlass_extensions/weight_only_quant_op.h" -#include "moe_gemm_kernels.h" namespace tensorrt_llm { namespace kernels { -namespace cutlass_kernels { - +namespace cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput; +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; template void sm90_generic_mixed_moe_gemm_kernelLauncher( - GroupedGemmInput inputs, + tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput + inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size); -} // namespace cutlass_kernels +} // namespace cutlass_kernels_oss } // namespace kernels } // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl index e28cb7b129..8f4d2f7630 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl @@ -1,13 +1,17 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ #ifdef __GNUC__ // Check if the compiler is GCC or Clang @@ -44,28 +48,30 @@ #pragma GCC diagnostic pop #endif // __GNUC__ +#include "moe_gemm_tma_ws_mixed_input_launcher.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h" +namespace tensorrt_llm { +namespace kernels { +namespace cutlass_kernels_oss { +using namespace tensorrt_llm::kernels::cutlass_kernels; namespace tk = tensorrt_llm::common; namespace tkc = tensorrt_llm::cutlass_extensions; using namespace cute; -namespace tensorrt_llm { -namespace kernels { -namespace cutlass_kernels { - template void sm90_generic_mixed_moe_gemm_kernelLauncher( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -181,40 +187,36 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher( hw_info.device_id = 0; hw_info.sm_count = sm_count_; - if (workspace_size != nullptr) { - const Args args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {inputs.num_experts, hopper_inputs.int4_groupwise_params.shape.problem_shapes, nullptr}, - {reinterpret_cast(hopper_inputs.ptr_b), hopper_inputs.stride_b, - reinterpret_cast(hopper_inputs.ptr_a), hopper_inputs.stride_a, - reinterpret_cast(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, - {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), - hopper_inputs.stride_c, reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, - hw_info}; - *workspace_size = gemm.get_workspace_size(args); - return; - } - - assert(group_size == int(inputs.groupwise_quant_group_size)); arguments = Args{ cutlass::gemm::GemmUniversalMode::kGrouped, {inputs.num_experts, hopper_inputs.int4_groupwise_params.shape.problem_shapes, nullptr}, - {reinterpret_cast(hopper_inputs.ptr_b), hopper_inputs.stride_b, - reinterpret_cast(hopper_inputs.ptr_a), hopper_inputs.stride_a, + {reinterpret_cast(hopper_inputs.ptr_weight), + reinterpret_cast(hopper_inputs.stride_weight), + reinterpret_cast(hopper_inputs.ptr_act), + reinterpret_cast(hopper_inputs.stride_act), reinterpret_cast(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, - {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), hopper_inputs.stride_c, - reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, + reinterpret_cast(hopper_inputs.int4_groupwise_params.stride_s_a), group_size}, + {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), + reinterpret_cast(hopper_inputs.stride_c), + reinterpret_cast(hopper_inputs.ptr_d), + reinterpret_cast(hopper_inputs.stride_d)}, hw_info}; + assert(group_size == int(inputs.groupwise_quant_group_size)); + if (workspace_size != nullptr) { + *workspace_size = gemm.get_workspace_size(arguments); + return; + } + if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) { TLLM_LOG_ERROR("[Mixed dtype WS grouped 
GEMM] given workspace size insufficient, %d < %d.", gemm.get_workspace_size(arguments), hopper_inputs.gemm_workspace_size); } + // This is not initialized during workspace size calculation so check after + TLLM_CHECK_WITH_INFO(hopper_inputs.swap_ab, + "swap_ab must be true for mixed dtype WS grouped GEMM"); + auto can_implement = gemm.can_implement(arguments); if (can_implement != cutlass::Status::kSuccess) { std::string err_msg = "mixed dtype WS grouped cutlass kernel will fail for params. Error: " + @@ -239,6 +241,6 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher( return; } -} // namespace cutlass_kernels +} // namespace cutlass_kernels_oss } // namespace kernels } // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu index 1a350efc15..1072cdd1fa 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_BF16 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu index 3d020f2618..c1d40e33ac 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #if defined(ENABLE_BF16) && defined(ENABLE_FP4) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu index 8fd27b4c3f..da1adbac53 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_BF16 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu index 98ec5e7a64..b10a7f6713 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_BF16 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu index 94ed59c0a6..cbf13d5f6f 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_BF16 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu index a3af6d6c8a..aba083585f 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. 
Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { template class MoeGemmRunner; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu index c4533161bd..ce4b57cc69 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #if defined(ENABLE_FP4) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu index 9a464b8311..d216bed89c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { template class MoeGemmRunner; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu index 92159da7ed..b9bdae53ac 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { template class MoeGemmRunner; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu index d26e9609fd..747f9b29e8 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { template class MoeGemmRunner; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu index a0137fd1c6..a8c11e0692 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP4 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu index c00b77dbc1..6bc740c5fa 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP4 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu index 7235cb5119..08b7ce1930 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP8 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu index 01a096b526..7dbf9f6265 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP8 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 16b39246cb..708406aab6 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -67,7 +67,7 @@ #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { // ============================= Variable batched Gemm things =========================== template ::type; - using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter::type; - using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; + using ElementType = typename cutlass_kernels::TllmToCutlassTypeAdapter::type; + using CutlassGemmOutputType = + typename cutlass_kernels::TllmToCutlassTypeAdapter::type; + using CutlassWeightType = typename cutlass_kernels::TllmToCutlassTypeAdapter::type; if (!inputs.use_fused_moe) { // We need separate config for each architecture since we will target different tensorcore // instructions. For float, we do not target TCs. @@ -213,9 +214,9 @@ struct genericMoeGemmKernelLauncher { // support fp16 or // bf16) { - sm80_generic_fused_moe_gemm_kernelLauncher( + tensorrt_llm::kernels::cutlass_kernels_oss::sm80_generic_fused_moe_gemm_kernelLauncher< + ElementType, CutlassWeightType, ThreadblockShape::kM, ThreadblockShape::kN, + ThreadblockShape::kK, Stages, EpilogueTag>( reinterpret_cast(inputs.A), reinterpret_cast(inputs.B), reinterpret_cast(inputs.biases), inputs.bias_is_broadcast, @@ -254,18 +255,19 @@ static void dispatch(GroupedGemmInput= 80) && (!isFp8 || std::is_same_v) && !isFp4) { // dispatch for quant op type - auto* launcher = kernels::cutlass_kernels::genericMoeGemmKernelLauncher< + auto* launcher = tensorrt_llm::kernels::cutlass_kernels_oss::genericMoeGemmKernelLauncher< T, WeightType, GemmOutputType, Arch, cutlass::WeightOnlyQuantOp::UNDEFINED, EpilogueTag, ThreadblockShape, WarpShape, Stages>::call; if (!std::is_same_v && inputs.groupwise_quant_group_size > 0) { - launcher = inputs.zeros ? kernels::cutlass_kernels::genericMoeGemmKernelLauncher< - T, WeightType, GemmOutputType, Arch, - cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, - EpilogueTag, ThreadblockShape, WarpShape, Stages>::call - : kernels::cutlass_kernels::genericMoeGemmKernelLauncher< - T, WeightType, GemmOutputType, Arch, - cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, EpilogueTag, - ThreadblockShape, WarpShape, Stages>::call; + launcher = inputs.zeros + ? 
tensorrt_llm::kernels::cutlass_kernels_oss::genericMoeGemmKernelLauncher< + T, WeightType, GemmOutputType, Arch, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, EpilogueTag, + ThreadblockShape, WarpShape, Stages>::call + : tensorrt_llm::kernels::cutlass_kernels_oss::genericMoeGemmKernelLauncher< + T, WeightType, GemmOutputType, Arch, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, EpilogueTag, + ThreadblockShape, WarpShape, Stages>::call; } launcher(inputs, sm_count_); } else { @@ -519,17 +521,23 @@ void dispatchMoeGemmToCutlass( } } +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss + +namespace tensorrt_llm::kernels::cutlass_kernels { + template std::vector -MoeGemmRunner::getConfigs() const { - return getConfigs(sm_); +MoeGemmRunner::getConfigs( + bool supports_finalize_fusion) const { + return getConfigs(sm_, supports_finalize_fusion); } template std::vector -MoeGemmRunner::getConfigs(int sm) { +MoeGemmRunner::getConfigs(int sm, + bool supports_finalize_fusion) { std::vector candidate_configs = - getTmaWarpSpecializedConfigs(sm); + getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion); std::vector ampere_configs = getAmpereConfigs(sm); std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); return candidate_configs; @@ -552,19 +560,21 @@ MoeGemmRunner::getAmpereConfigs(int sm auto config_type_param = static_cast( weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || + if (!tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || (use_w4afp8 && sm != 89) || use_wfp4a16) { return {}; } std::vector ampere_configs = - kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, + config_type_param); return ampere_configs; } template std::vector -MoeGemmRunner::getTmaWarpSpecializedConfigs(int sm) { +MoeGemmRunner::getTmaWarpSpecializedConfigs( + int sm, bool supports_finalize_fusion) { using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; static constexpr auto weight_only_flag = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; @@ -577,28 +587,32 @@ MoeGemmRunner::getTmaWarpSpecializedCo static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; static constexpr auto fp4_only_flag = - (use_fp4 || use_wfp4afp4) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE; + (use_fp4 || use_wfp4afp8) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE; + static constexpr auto fp8fp4_mixed_flag = + use_wfp4afp8 ? CutlassGemmConfig::FP8FP4_MIXED : CutlassGemmConfig::NONE; auto config_type_param = static_cast( weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_blackwell | enable_hopper | - fp8_only_flag | fp4_only_flag); + fp8_only_flag | fp4_only_flag | fp8fp4_mixed_flag); TLLM_CHECK_WITH_INFO(!(enable_blackwell && enable_hopper), "Blackwell and hopper flags are mutually exclusive"); + sm = use_wfp4afp8 && sm == 103 ? 
100 : sm; if (sm >= 100 && sm < 120 && - !kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) { + !tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) { TLLM_LOG_TRACE( "Blackwell is not supported for this configuration, not selecting any TMA WS " "implementations"); return {}; } if ((sm == 120 || sm == 121) && - !kernels::cutlass_kernels::isValidSM120MOESpecialisation()) { + !tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation()) { TLLM_LOG_TRACE( "Blackwell SM120 is not supported for this configuration, not selecting any TMA WS " "implementations"); return {}; } - if (enable_hopper && !kernels::cutlass_kernels::isValidHopperMOESpecialisation()) { + if (enable_hopper && + !tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation()) { TLLM_LOG_TRACE( "Hopper is not supported for this configuration, not selecting any TMA WS implementations"); return {}; @@ -606,6 +620,51 @@ MoeGemmRunner::getTmaWarpSpecializedCo std::vector tma_ws_configs = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + + if (sm == 103 && use_fp4) { + // Explicitly select SM100 as well + auto sm100_configs = tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs( + 100, max_split_k, config_type_param); + std::copy(sm100_configs.begin(), sm100_configs.end(), std::back_inserter(tma_ws_configs)); + } + + if (supports_finalize_fusion) { + // Duplicate the configs and set the epilogue fusion type to FINALIZE + auto finalize_configs = tma_ws_configs; + std::transform(finalize_configs.begin(), finalize_configs.end(), + std::back_inserter(tma_ws_configs), [](auto& config) { + config.epilogue_fusion_type = + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + return config; + }); + + // Finalize fusion is only supported for TMA epilogue schedule + tma_ws_configs.erase( + std::remove_if( + tma_ws_configs.begin(), tma_ws_configs.end(), + [](auto& config) { + return config.epilogue_fusion_type == + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE && + config.epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM; + }), + tma_ws_configs.end()); + } + + auto swap_ab_configs = tma_ws_configs; + std::transform(swap_ab_configs.begin(), swap_ab_configs.end(), std::back_inserter(tma_ws_configs), + [](auto& config) { + TLLM_CHECK_WITH_INFO(!config.swap_ab, "Swap AB is already set"); + config.swap_ab = true; + return config; + }); + + if (use_w4_groupwise) { + // w4 groupwise implementation requires swap_ab to be true + tma_ws_configs.erase(std::remove_if(tma_ws_configs.begin(), tma_ws_configs.end(), + [](auto& config) { return !config.swap_ab; }), + tma_ws_configs.end()); + } + return tma_ws_configs; } @@ -617,12 +676,15 @@ bool MoeGemmRunner::isTmaWarpSpecializ } template -bool MoeGemmRunner::supportsTmaWarpSpecialized() const { - return (sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation()) || - (sm_ >= 100 && sm_ < 120 && - kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) || - ((sm_ == 120 || sm_ == 121) && - kernels::cutlass_kernels::isValidSM120MOESpecialisation()); +bool MoeGemmRunner::supportsTmaWarpSpecialized(int sm) { + return (sm == 90 && + tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation()) || + (sm >= 100 && sm < 120 && + tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< + T, WeightType>()) || + ((sm == 120 || sm == 121) && + 
tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation()); } template @@ -677,63 +739,64 @@ void MoeGemmRunner::dispatchToArch( "Hopper configuration provided for non-Hopper architecture"); if (sm_ >= 75 && sm_ < 80) { -#ifdef ENABLE_FP4 - if constexpr (!std::is_same_v) { - dispatchMoeGemmToCutlass( +#if defined(ENABLE_FP4) + constexpr bool is_fp4 = std::is_same_v; +#else + constexpr bool is_fp4 = false; +#endif + if constexpr (!is_fp4) { + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } else { TLLM_THROW("FP4 data type is not supported on SM < 90"); } + } else if (sm_ >= 80 && sm_ < 90) { +#if defined(ENABLE_FP4) + constexpr bool is_fp4 = std::is_same_v; #else - TLLM_THROW("FP4 data type is not supported on SM < 90"); + constexpr bool is_fp4 = false; #endif - } else if (sm_ >= 80 && sm_ < 90) { - if constexpr (use_fp8 || use_w4afp8) { + if constexpr (!is_fp4) { + if constexpr (use_fp8 || use_w4afp8) { #if defined(ENABLE_FP8) - static_assert( - !std::is_same_v && !std::is_same_v, - "FP8 GEMM Output not supported"); + static_assert(!std::is_same_v && + !std::is_same_v, + "FP8 GEMM Output not supported"); #endif - TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); - } else { -#ifdef ENABLE_FP4 - if constexpr (std::is_same_v) { - TLLM_THROW("FP4 data type is not supported on SM < 90"); + + TLLM_CHECK_WITH_INFO(sm_ == 89, + "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + cutlass_kernels_oss::dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); } else { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } -#else - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); -#endif + } else { + TLLM_THROW("FP4 data type is not supported on SM < 90"); } } else if (sm_ >= 90) { - // For SM120+ FP8 MoE, redirect to SM89 (Ada) FP8 kernel implementations. - if constexpr (use_fp8) { + // For SM120+ pure FP8 MoE (not FP8 x FP4), redirect to SM89 (Ada) FP8 kernel implementations. + if constexpr (use_fp8 && !use_wfp4afp8) { if (sm_ >= 120) { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); return; } } - if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation< - T, WeightType, EpilogueTag>() && + if constexpr (tensorrt_llm::kernels::cutlass_kernels:: + isValidTmaWarpSpecializedMOESpecialisation() && !use_w4_groupwise) { // We allow both tma warp specialized and SM80 configurations to coexist because for some // cases with small numbers of tokens SM80 is faster. 
We check here to see which is selected if (inputs.gemm_config.sm_version >= 90) { - bool is_same_sm = inputs.gemm_config.sm_version == sm_; - // gemm_config.sm_version indicates the kernel pipeline, which is always 100 for 100, 103, - // 110 below logging helps confirming the cutlass pipeline matches the device major version - bool is_sm110 = inputs.gemm_config.sm_version == 100 && sm_ == 110; - bool is_sm103 = inputs.gemm_config.sm_version == 100 && sm_ == 103; - // SM120 and SM121 are architecturally identical - bool is_sm120 = (inputs.gemm_config.sm_version == 120) && (sm_ == 120 || sm_ == 121); - TLLM_CHECK_WITH_INFO(is_same_sm || is_sm110 || is_sm103 || is_sm120, + // Check the major version of the SM matches + TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10, "Using SM %d configuration for SM %d device", inputs.gemm_config.sm_version, sm_); TLLM_CHECK_WITH_INFO(inputs.biases != nullptr || hopper_inputs.ptr_c == nullptr, @@ -746,11 +809,11 @@ void MoeGemmRunner::dispatchToArch( auto select_function = [&]() { switch (hopper_inputs.fusion) { case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE: - return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized< + return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized< T, WeightType, OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE>; case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE: - return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized< + return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized< T, WeightType, OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>; case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::ACTIVATION: @@ -775,16 +838,16 @@ void MoeGemmRunner::dispatchToArch( "w4afp8 is only supported for TMA warp specialization"); // EpilogueTag is ignored if (inputs.k % 512 == 0) { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass< + T, WeightType, ScaleBiasType, cutlass_extensions::EpilogueOpDefault, 4>( inputs, hopper_inputs, multi_processor_count_, nullptr); } else if (inputs.k % 256 == 0) { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass< + T, WeightType, ScaleBiasType, cutlass_extensions::EpilogueOpDefault, 2>( inputs, hopper_inputs, multi_processor_count_, nullptr); } else if (inputs.k % 128 == 0) { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass< + T, WeightType, ScaleBiasType, cutlass_extensions::EpilogueOpDefault, 1>( inputs, hopper_inputs, multi_processor_count_, nullptr); } else { TLLM_THROW("Invalid GEMM K size %d", (int)inputs.k); @@ -796,16 +859,16 @@ void MoeGemmRunner::dispatchToArch( TLLM_CHECK_WITH_INFO(inputs.gemm_config.is_tma_warp_specialized, "wfp4a16 is only supported for TMA warp specialization"); // EpilogueTag is ignored - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass< + T, WeightType, ScaleBiasType, cutlass_extensions::EpilogueOpDefault, 1>( inputs, hopper_inputs, multi_processor_count_, nullptr); return; } #endif // Do Ampere case instead - if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) { + if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation< + T, WeightType, EpilogueTag>()) { TLLM_CHECK_WITH_INFO(!use_fp8, "No 
fallback FP8 implementation available"); TLLM_CHECK_WITH_INFO(use_w4afp8 || !hopper_inputs.isValid(), "Non-specialized Hopper implementation is being rerouted to fallback " @@ -818,10 +881,12 @@ void MoeGemmRunner::dispatchToArch( "Using SM %d configuration for SM80 fallback implementation", inputs.gemm_config.sm_version); if constexpr (use_fp8) { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } else { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } } else { @@ -848,18 +913,21 @@ template ::calcMaxWorkspaceSize( int num_experts) const { if constexpr (use_w4_groupwise) { - return calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput( + return cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput( num_experts, multi_processor_count_); } if (!supportsTmaWarpSpecialized()) { return 0; } - if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation< + if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation< T, WeightType>() && !use_w4afp8 && !use_wfp4a16) { - auto configs = getTmaWarpSpecializedConfigs(sm_); + // Finalize fusion may not actually be supported by the kernel, + // if they are not we will catch the error and skip them + auto configs = getTmaWarpSpecializedConfigs(sm_, true); auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; - if constexpr (use_wfp4afp4) { + if constexpr (use_wfp4afp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; } else if (use_fp4) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; @@ -867,17 +935,19 @@ size_t MoeGemmRunner::calcMaxWorkspace size_t max_size = 0; bool has_config = false; for (auto conf : configs) { -#define CALC_SIZE_FUSION(FUSION) \ - do { \ - try { \ - size_t size = calcMaxWorkspaceSizeTmaWarpSpecialized( \ - num_experts, conf, multi_processor_count_, fpX_block_scaling_type); \ - max_size = std::max(max_size, size); \ - has_config = true; \ - } catch (tensorrt_llm::common::TllmException const& e) { \ - TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size %s", \ - e.what()); \ - } \ +#define CALC_SIZE_FUSION(FUSION) \ + do { \ + try { \ + size_t size = \ + cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecialized( \ + num_experts, conf, multi_processor_count_, fpX_block_scaling_type); \ + max_size = std::max(max_size, size); \ + has_config = true; \ + } catch (tensorrt_llm::common::TllmException const& e) { \ + TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size %s", \ + e.what()); \ + } \ } while (0) CALC_SIZE_FUSION(TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); @@ -927,9 +997,6 @@ void MoeGemmRunner::moeGemmBiasAct( case ActivationType::Geglu: runGemm(inputs, hopper_inputs); break; - case ActivationType::Relu2: - TLLM_THROW("Relu2 is not supported."); - break; case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index c764cb6c90..5adacd0ce2 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ 
b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -49,6 +49,7 @@ #include #include +#include #include #include "../include/moe_gemm_kernels.h" @@ -59,15 +60,63 @@ #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion; +template +auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilogue_schedule, + bool dynamic_cga, bool swap_ab) { + auto select_swap_ab = [dynamic_cga, epilogue_schedule](auto swap_ab_t) { + auto select_dynamic_cga = [epilogue_schedule](auto dynamic_cga_t) { +#if defined(ENABLE_FP4) + constexpr bool is_block_scaled = + std::is_same_v || std::is_same_v; +#else + constexpr bool is_block_scaled = false; +#endif + if constexpr ((!is_block_scaled || Arch::kMinComputeCapability == 103) && + FUSION != EpilogueFusion::FINALIZE) { + auto func_map = std::array{ + &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, + EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value>, + &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayTmaWarpSpecialized, + EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value> + + }; + bool const tma_epilogue = + epilogue_schedule == cutlass_extensions::EpilogueScheduleType::TMA; + return func_map[tma_epilogue]; + } else { + static_assert(FUSION == EpilogueFusion::FINALIZE || Arch::kMinComputeCapability != 103, + "SM103 should support both epilogue schedules"); + TLLM_CHECK_WITH_INFO( + epilogue_schedule == cutlass_extensions::EpilogueScheduleType::TMA, + "No Smem epilogue schedule is not supported for block scaled types or finalize fusion"); + return &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayTmaWarpSpecialized, + EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value>; + } + }; + return dynamic_cga ? select_dynamic_cga(tensorrt_llm::common::ConstBool{}) + : select_dynamic_cga(tensorrt_llm::common::ConstBool{}); + }; + return swap_ab ? 
select_swap_ab(tensorrt_llm::common::ConstBool{}) + : select_swap_ab(tensorrt_llm::common::ConstBool{}); +} + template -void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmInput hopper_input, - int num_experts, int multi_processor_count, - cudaStream_t stream, int* occupancy, - size_t* workspace_size) { +void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( + TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, + cudaStream_t stream, int* occupancy, size_t* workspace_size) { static_assert( (Arch::kMinComputeCapability == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation()) || @@ -79,15 +128,6 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn TLLM_CHECK_WITH_INFO(workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information"); - // auto func = hopper_input.ptr_c ? - // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper - // : - // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper; - // TODO Re-enable bias when CUTLASS supports it - if constexpr (Arch::kMinComputeCapability < 90) { TLLM_THROW("Invalid architecture instantiated"); } @@ -98,6 +138,13 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn "build_wheel.py."); } #endif +#ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS + else if constexpr (Arch::kMinComputeCapability == 103) { + TLLM_THROW( + "Please recompile with support for blackwell by passing 103-real as an arch to " + "build_wheel.py."); + } +#endif #ifndef COMPILE_BLACKWELL_TMA_GROUPED_GEMMS else if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { TLLM_THROW( @@ -113,39 +160,91 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn } #endif else { -#ifdef ENABLE_FP4 - auto getFunc = [&]() { - if constexpr (std::is_same_v && std::is_same_v) { - TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, - "MXFPX is the only supported scaling type for WFP4AFP8"); - return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, true, - false>; - } else { - TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, - "MXFPX is not supported for the selected weight combination"); - return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, false, - false>; - } - }; - getFunc()(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); +#if defined(ENABLE_FP4) + constexpr static bool is_wfp4afp8 = + std::is_same_v && std::is_same_v; #else - TLLM_THROW("FP4 data type is not supported on this architecture and CUDA version"); + constexpr static bool is_wfp4afp8 = false; #endif + if constexpr (is_wfp4afp8) { + TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, + "MXFPX is the only supported scaling type for WFP4AFP8"); + } else { + TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, + "MXFPX is not supported for the selected weight combination"); + } 
+ + if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { + bool const dynamic_cga = + gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; + bool const swap_ab = hopper_input.swap_ab; + auto cluster_shape = + cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); + auto cluster_shape_cute = cute::Shape{ + std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; + auto cluster_shape_fallback = + cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); + auto cluster_shape_cute_fallback = cute::Shape{ + std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; + + // HACK debug the gemm_config used to produce selected_func + // std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version + // << ", tile_config_sm100=" << static_cast(gemm_config.tile_config_sm100) + // << ", epilogue_schedule=" << static_cast(gemm_config.epilogue_schedule) + // << ", dynamic_cluster_shape=" << + // static_cast(gemm_config.dynamic_cluster_shape) + // << ", fallback_cluster_shape=" + // << static_cast(gemm_config.fallback_cluster_shape) << std::endl; + + auto selected_func = + getDispatchFunctionForSM100( + gemm_config.epilogue_schedule, dynamic_cga, swap_ab); + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); + } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) { + using EpilogueSchedule = void; // These are hardcoded in the launcher + constexpr bool dynamic_cga = false; + auto selected_func = + hopper_input.swap_ab + ? kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, is_wfp4afp8, dynamic_cga, false, true> + : kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, is_wfp4afp8, dynamic_cga, false, false>; + + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, {}, {}); + } } } -template +template constexpr bool are_tile_shapes_supported_sm100() { + // We use a runtime cluster shape for SM100, so we only support 1x1x1 and 2x1x1 cluster shapes. + if (cute::size<0>(ClusterShape{}) > 2 || cute::size<1>(ClusterShape{}) != 1 || + cute::size<2>(ClusterShape{}) != 1) { + return false; + } + using namespace cute; - using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // This is the epilogue shape. 
The MMA shape will be twice this for 2SM constexpr auto TileM = size<0>(CtaShape{}); constexpr auto TileN = size<1>(CtaShape{}); + if constexpr (Arch::kMinComputeCapability == 103) { +#if defined(ENABLE_FP4) + return std::is_same_v && std::is_same_v && + TileM == 128 && (TileN == 128 || TileN == 256); +#else + return false; +#endif + } + if constexpr (TileM != 64 && TileM != 128) { return false; } @@ -181,14 +280,13 @@ constexpr bool are_tile_shapes_supported_sm100() { return true; } -template +template constexpr bool are_tile_shapes_supported_sm120() { using namespace cute; if constexpr (cute::size<0>(ClusterShape{}) != 1 || cute::size<1>(ClusterShape{}) != 1 || cute::size<2>(ClusterShape{}) != 1) { return false; } - using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // This is the epilogue shape. The MMA shape will be twice this for 2SM constexpr auto TileM = size<0>(CtaShape{}); constexpr auto TileN = size<1>(CtaShape{}); @@ -216,7 +314,7 @@ template constexpr bool are_tile_shapes_supported() { if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { - return are_tile_shapes_supported_sm100(); + return are_tile_shapes_supported_sm100(); } else if constexpr (Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121) { return are_tile_shapes_supported_sm120(); } @@ -247,14 +345,16 @@ void dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, size_t* workspace_size) { using namespace cute; + // This uses the fallback cluster shape for sm100 if a dynamic cluster shape is requested. switch (gemm_config.cluster_shape) { #define SHAPE_CASE(M, N, K) \ case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: { \ using ClusterShape = Shape<_##M, _##N, _##K>; \ if constexpr (are_tile_shapes_supported()) { \ - dispatchMoeGemmSelectBiasTmaWarpSpecialized( \ - hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \ + dispatchMoeGemmFinalDispatchTmaWarpSpecialized( \ + hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, \ + workspace_size); \ break; \ } else { \ TLLM_THROW( \ @@ -275,7 +375,8 @@ void dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( #undef SHAPE_CASE default: - TLLM_THROW("Unsupported config %d for MoE gemm.", (int)gemm_config.cluster_shape); + TLLM_THROW("Unsupported cluster shape config %d for MoE gemm.", + (int)gemm_config.cluster_shape); } } // namespace tensorrt_llm @@ -301,15 +402,16 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( workspace_size); \ break; \ } -#define DEFAULT_CASE(SMVERSION) \ - case cutlass_extensions::CutlassTileConfigSM##SMVERSION::Undefined: \ - TLLM_THROW("GEMM config undefined."); \ - break; \ - case cutlass_extensions::CutlassTileConfigSM##SMVERSION::ChooseWithHeuristic: \ - TLLM_THROW("GEMM config should have already been set by heuristic."); \ - break; \ - default: \ - TLLM_THROW("Unsupported config %d for MoE gemm.", (int)gemm_config.tile_config_sm##SMVERSION); \ +#define DEFAULT_CASE(SMVERSION) \ + case cutlass_extensions::CutlassTileConfigSM##SMVERSION::Undefined: \ + TLLM_THROW("GEMM config undefined."); \ + break; \ + case cutlass_extensions::CutlassTileConfigSM##SMVERSION::ChooseWithHeuristic: \ + TLLM_THROW("GEMM config should have already been set by heuristic."); \ + break; \ + default: \ + TLLM_THROW("Unsupported tile shape config %d for MoE gemm.", \ + 
(int)gemm_config.tile_config_sm##SMVERSION); \ break; if (gemm_config.sm_version == 90) { @@ -327,29 +429,29 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( } else { TLLM_THROW("Unsupported SM90 configuration requested"); } - } else if (gemm_config.sm_version == 110) { + } +#if defined(ENABLE_FP4) && defined(COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS) + // Check this before SM100 because we fall back to SM100 if not NVFP4 + else if (gemm_config.sm_version == 103 && std::is_same_v && + std::is_same_v) { if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< T, WeightType, EpilogueTag, FUSION>()) { switch (gemm_config.tile_config_sm100) { - SHAPE_CASE(100, 64, 64, 128) - SHAPE_CASE(100, 64, 128, 128) - SHAPE_CASE(100, 64, 256, 128) + SHAPE_CASE(103, 128, 128, 128) + SHAPE_CASE(103, 128, 256, 128) - SHAPE_CASE(100, 128, 16, 128) - SHAPE_CASE(100, 128, 32, 128) - SHAPE_CASE(100, 128, 64, 128) - SHAPE_CASE(100, 128, 128, 128) - SHAPE_CASE(100, 128, 256, 128) - - DEFAULT_CASE(100) + DEFAULT_CASE(100) // 100 because we use the same member variable for SM100 and SM103 } } else { - TLLM_THROW("Unsupported SM110 configuration requested"); + TLLM_THROW("Unsupported SM103 configuration requested"); } - } else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 110) { + } +#endif + else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120) { if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< T, WeightType, EpilogueTag, FUSION>()) { switch (gemm_config.tile_config_sm100) { + SHAPE_CASE(100, 64, 32, 128) SHAPE_CASE(100, 64, 64, 128) SHAPE_CASE(100, 64, 128, 128) SHAPE_CASE(100, 64, 256, 128) @@ -360,13 +462,8 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( SHAPE_CASE(100, 128, 128, 128) SHAPE_CASE(100, 128, 256, 128) - SHAPE_CASE(100, 256, 64, 128) - SHAPE_CASE(100, 256, 128, 128) - SHAPE_CASE(100, 256, 256, 128) - // SHAPE_CASE(100, 128, 128, 64) // SHAPE_CASE(100, 128, 256, 64) - // SHAPE_CASE(100, 256, 256, 64) DEFAULT_CASE(100) } } else { @@ -404,4 +501,4 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecialized( return count; } -} // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h index 3375a60716..eaaedf4258 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h @@ -57,8 +57,10 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput; +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; namespace tk = tensorrt_llm::common; namespace tkc = tensorrt_llm::cutlass_extensions; @@ -69,6 +71,7 @@ template inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); #ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS switch (inputs.gemm_config.mainloop_schedule) { case tkc::MainloopScheduleType::COOPERATIVE: @@ -120,6 +123,7 @@ template 
inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); switch (inputs.gemm_config.cluster_shape) { case tkc::ClusterShape::ClusterShape_1x1x1: sm90_dispatch_mainloop_schedules inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually // perform the best for mixed type gemms. @@ -164,11 +169,12 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( #else constexpr int Ntile = 128; constexpr int Ktile = 128 * PackedScalesNum / sizeof(T); - TLLM_CHECK(sizeof(T) == 2); + TLLM_CHECK(sizeof(T) == 1); #endif using _Ntile = Int; using _Ktile = Int; + switch (inputs.gemm_config.tile_config_sm90) { case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: sm90_dispatch_moe_mixed_dtype_gemm_config size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_count_) { size_t count = 0; -#ifdef ENABLE_FP4 +#if defined(ENABLE_FP4) constexpr int Ktile = (std::is_same_v) ? 256 : 512; #else constexpr int Ktile = 512; @@ -267,4 +273,4 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_ return count; } -} // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu index f240680c6b..52cd03887b 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu @@ -14,43 +14,45 @@ * limitations under the License. */ +#include "../include/moe_gemm_kernels.h" #include "cute/tensor.hpp" #include "cutlass/conv/convolution.h" #include "cutlass/cutlass.h" -#include "moe_gemm_kernels.h" // Order matters here, packed_stride.hpp is missing cute and convolution includes #include "cutlass/util/packed_stride.hpp" #include "tensorrt_llm/common/logger.h" namespace tensorrt_llm::kernels::cutlass_kernels { -std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( +std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( int num_experts, FpXBlockScalingType scaling_type) { size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; - size_t stride_a_size = sizeof(StrideA) * num_experts; - size_t stride_b_size = sizeof(StrideB) * num_experts; - size_t stride_c_size = sizeof(StrideC) * num_experts; - size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; + size_t stride_act_size = std::max(sizeof(StrideA), sizeof(StrideB)) * num_experts; + size_t stride_weight_size = std::max(sizeof(StrideA), sizeof(StrideB)) * num_experts; + size_t stride_c_size = std::max(sizeof(StrideC), sizeof(StrideC_T)) * num_experts; + size_t stride_d_size = std::max(sizeof(StrideD), sizeof(StrideD_T)) * num_experts; size_t ptr_buf_size = sizeof(void*) * num_experts; size_t scale_buf_size = sizeof(float*) * num_experts; - size_t sf_a_size = sizeof(ElementSF*) * num_experts; - size_t sf_b_size = sizeof(ElementSF*) * num_experts; - size_t stride_sf_a_size = scaling_type == FpXBlockScalingType::MXFPX - ? 
sizeof(MXFPXBlockScaledConfig::LayoutSF) * num_experts - : sizeof(NVFP4BlockScaledConfig::LayoutSF) * num_experts; - size_t stride_sf_b_size = scaling_type == FpXBlockScalingType::MXFPX - ? sizeof(MXFPXBlockScaledConfig::LayoutSF) * num_experts - : sizeof(NVFP4BlockScaledConfig::LayoutSF) * num_experts; + size_t sf_act_size = sizeof(ElementSF*) * num_experts; + size_t sf_weight_size = sizeof(ElementSF*) * num_experts; + size_t stride_sf_act_size = scaling_type == FpXBlockScalingType::MXFPX + ? sizeof(MXFPXBlockScaledConfig::LayoutSF) * num_experts + : sizeof(NVFP4BlockScaledConfig::LayoutSF) * num_experts; + size_t stride_sf_weight_size = scaling_type == FpXBlockScalingType::MXFPX + ? sizeof(MXFPXBlockScaledConfig::LayoutSF) * num_experts + : sizeof(NVFP4BlockScaledConfig::LayoutSF) * num_experts; size_t int4_groupwise_problem_shape_size = sizeof(INT4GroupwiseParams::ProblemShapeInt::UnderlyingProblemShape) * num_experts; size_t int4_groupwise_sf_a_size = sizeof(INT4GroupwiseParams::SFA*) * num_experts; size_t int4_groupwise_stride_sf_a_size = sizeof(INT4GroupwiseParams::StrideSFA) * num_experts; + size_t ptr_token_map_size = sizeof(int**) * num_experts; + return std::array{problem_shape_size, - stride_a_size, - stride_b_size, + stride_act_size, + stride_weight_size, stride_c_size, stride_d_size, ptr_buf_size, @@ -58,13 +60,16 @@ std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( ptr_buf_size, ptr_buf_size, scale_buf_size, - sf_a_size, - sf_b_size, - stride_sf_a_size, - stride_sf_b_size, + sf_act_size, + sf_weight_size, + stride_sf_act_size, + stride_sf_weight_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, - int4_groupwise_stride_sf_a_size}; + int4_groupwise_stride_sf_a_size, + ptr_buf_size, + scale_buf_size, + ptr_token_map_size}; } size_t TmaWarpSpecializedGroupedGemmInput::workspaceSize(int num_experts, @@ -78,7 +83,7 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i size_t gemm_workspace_size, FpXBlockScalingType scaling_type) { auto buffers = workspaceBuffers(num_experts, scaling_type); - std::array pointers{}; + std::array pointers{}; TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); for (int i = 0; i < buffers.size(); i++) { @@ -89,23 +94,23 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i shape_info.num_groups = num_experts; shape_info.problem_shapes = reinterpret_cast(pointers[0]); shape_info.host_problem_shapes = nullptr; - stride_a = reinterpret_cast(pointers[1]); - stride_b = reinterpret_cast(pointers[2]); - stride_c = reinterpret_cast(pointers[3]); - default_epilogue.stride_d = reinterpret_cast(pointers[4]); + stride_act = reinterpret_cast(pointers[1]); + stride_weight = reinterpret_cast(pointers[2]); + stride_c = reinterpret_cast(pointers[3]); + stride_d = reinterpret_cast(pointers[4]); - ptr_a = reinterpret_cast(pointers[5]); - ptr_b = reinterpret_cast(pointers[6]); + ptr_act = reinterpret_cast(pointers[5]); + ptr_weight = reinterpret_cast(pointers[6]); ptr_c = reinterpret_cast(pointers[7]); - default_epilogue.ptr_d = reinterpret_cast(pointers[8]); + ptr_d = reinterpret_cast(pointers[8]); alpha_scale_ptr_array = reinterpret_cast(pointers[9]); - fpX_block_scaling_factors_A = reinterpret_cast(pointers[10]); - fpX_block_scaling_factors_B = reinterpret_cast(pointers[11]); + fpX_block_scaling_factors_act = reinterpret_cast(pointers[10]); + fpX_block_scaling_factors_weight = reinterpret_cast(pointers[11]); - 
fpX_block_scaling_factors_stride_A = pointers[12]; - fpX_block_scaling_factors_stride_B = pointers[13]; + fpX_block_scaling_factors_stride_act = pointers[12]; + fpX_block_scaling_factors_stride_weight = pointers[13]; int4_groupwise_params.shape.problem_shapes = reinterpret_cast(pointers[14]); @@ -114,27 +119,30 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i int4_groupwise_params.stride_s_a = reinterpret_cast(pointers[16]); + fused_finalize_epilogue.ptr_bias = reinterpret_cast(pointers[17]); + fused_finalize_epilogue.ptr_router_scales = reinterpret_cast(pointers[18]); + fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast(pointers[19]); + this->gemm_workspace = reinterpret_cast(gemm_workspace); this->gemm_workspace_size = gemm_workspace_size; } -void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams( - void* final_output, float const* router_scales, int64_t const* expert_first_token_offset, - int const* source_token_index, void const* bias, int hidden_size, int num_output_tokens) { +void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(void* final_output, + int hidden_size, + int num_output_tokens, + bool use_reduction) { fused_finalize_epilogue.ptr_final_output = final_output; - fused_finalize_epilogue.ptr_router_scales = router_scales; - fused_finalize_epilogue.ptr_bias = bias; - fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; - fused_finalize_epilogue.ptr_source_token_index = source_token_index; - - fused_finalize_epilogue.stride_final_output = cutlass::make_cute_packed_stride( - FusedFinalizeEpilogue::StrideFinalOutput{}, - transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); - fused_finalize_epilogue.stride_bias = - transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); - fused_finalize_epilogue.stride_router_scales = {}; + + fused_finalize_epilogue.stride_final_output = + cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, + cute::make_shape(num_output_tokens, hidden_size, 1)); + fused_finalize_epilogue.stride_final_output_transposed = + cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput_T{}, + cute::make_shape(hidden_size, num_output_tokens, 1)); fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; + fused_finalize_epilogue.shape_override = hidden_size; + fused_finalize_epilogue.use_reduction = use_reduction; } std::string TmaWarpSpecializedGroupedGemmInput::toString() const { @@ -142,32 +150,29 @@ std::string TmaWarpSpecializedGroupedGemmInput::toString() const { ss << "Hopper Input Information: " << (isValid() ? 
"valid" : "null") << "\n"; if (isValid()) { using PrintType = void const*; - ss << "Ptr A: " << (PrintType)ptr_a << " with Stride: " << (PrintType)stride_a << ",\n" - << "Ptr B: " << (PrintType)ptr_b << " with Stride: " << (PrintType)stride_b << ",\n" + ss << "Ptr Act: " << (PrintType)ptr_act << " with Stride: " << (PrintType)stride_act << ",\n" + << "Ptr Weight: " << (PrintType)ptr_weight << " with Stride: " << (PrintType)stride_weight + << ",\n" << "Ptr C: " << (PrintType)ptr_c << " with Stride: " << (PrintType)stride_c << "\n"; ss << "Epilogue Fusion: " << (int)fusion << ",\n"; if (fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) { ss << "Final Output: " << (PrintType)fused_finalize_epilogue.ptr_final_output; ss << " with Stride: " << fused_finalize_epilogue.stride_final_output; ss << ",\nBias: " << (PrintType)fused_finalize_epilogue.ptr_bias; - ss << " with Stride: " << fused_finalize_epilogue.stride_bias; ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nExpert Offset: " - << (PrintType)fused_finalize_epilogue.ptr_expert_first_token_offset; ss << ", Source Map: " << (PrintType)fused_finalize_epilogue.ptr_source_token_index; } else { - ss << "Ptr D: " << (PrintType)default_epilogue.ptr_d; - ss << " with Stride: " << (PrintType)default_epilogue.stride_d; + ss << "Ptr D: " << (PrintType)ptr_d; + ss << " with Stride: " << (PrintType)stride_d; } ss << '\n'; ss << "Alpha scale ptr: " << (PrintType)alpha_scale_ptr_array << "\n"; ss << "FpX Block Scaling Type: " << (int)fpX_block_scaling_type << "\n"; - ss << "Fp4 Block Scaling Factors A: " << (PrintType)fpX_block_scaling_factors_A - << ", with Stride: " << (PrintType)fpX_block_scaling_factors_stride_A << "\n"; - ss << "Fp4 Block Scaling Factors B: " << (PrintType)fpX_block_scaling_factors_B - << ", with Stride: " << (PrintType)fpX_block_scaling_factors_stride_B << "\n"; + ss << "Fp4 Block Scaling Factors Act: " << (PrintType)fpX_block_scaling_factors_act + << ", with Stride: " << (PrintType)fpX_block_scaling_factors_stride_act << "\n"; + ss << "Fp4 Block Scaling Factors Weight: " << (PrintType)fpX_block_scaling_factors_weight + << ", with Stride: " << (PrintType)fpX_block_scaling_factors_stride_weight << "\n"; ss << "Gemm Workspace: " << (PrintType)gemm_workspace << ", with Size: " << gemm_workspace_size << "\n"; } diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index d0bcbb978d..fb9ae80f2f 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -16,9 +16,9 @@ #pragma once +#include "../include/moe_gemm_kernels.h" #include "cutlass/arch/mma_sm90.h" #include "cutlass_extensions/epilogue_helpers.h" -#include "moe_gemm_kernels.h" #ifdef ENABLE_FP4 #include @@ -32,12 +32,16 @@ template constexpr bool isValidSM120MOESpecialisation() { -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) && \ - defined(ENABLE_FP4) // TODO Is there a better choice - return cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better 
choice +#if defined(ENABLE_FP4) + return ((cutlass::platform::is_same::value && + cutlass::platform::is_same::value) || + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value)) && + cutlass::platform::is_same::value; +#else + return false; +#endif #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -49,6 +53,7 @@ template constexpr bool isValidBlackwellMOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // TODO Is there a better choice +#if defined(ENABLE_FP4) return (cutlass::platform::is_same::value || #if defined(ENABLE_FP4) (cutlass::platform::is_same::value && @@ -57,8 +62,11 @@ constexpr bool isValidBlackwellMOESpecialisation() { false #endif ) && - cutlass::platform::is_same::value && - Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + cutlass::platform::is_same::value; +#else + return cutlass::platform::is_same::value && + cutlass::platform::is_same::value; +#endif #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -73,15 +81,12 @@ constexpr bool isValidHopperMOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) return (cutlass::platform::is_same::value || (cutlass::platform::is_same::value && - cutlass::platform::is_same::value) || + cutlass::platform::is_same::value) #ifdef ENABLE_FP4 - (cutlass::platform::is_same<__nv_fp4_e2m1, WeightType>::value && - !cutlass::platform::is_same::value) -#else - false + || (cutlass::platform::is_same<__nv_fp4_e2m1, WeightType>::value && + !cutlass::platform::is_same::value) #endif ) - #ifdef ENABLE_FP4 && !cutlass::platform::is_same::value #endif @@ -98,7 +103,8 @@ template constexpr bool isValidTmaWarpSpecializedMOESpecialisation() { // Check at least one of the implementations are valid - return isValidBlackwellMOESpecialisation() || + return isValidSM120MOESpecialisation() || + isValidBlackwellMOESpecialisation() || isValidHopperMOESpecialisation(); } diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 3ea148c780..516e05c8fc 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -350,6 +350,8 @@ def __init__( use_mxfp8_act_scaling, ) self.activation_type = activation_type + # Set by tuning flow to indicate which GEMM stage (1 or 2) to filter tactics for + self.gemm_idx_for_tuning: Optional[int] = None if instance_key not in MoERunner.runner_dict: MoERunner.runner_dict[instance_key] = module.init( @@ -368,7 +370,20 @@ def get_valid_tactics( inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - return list(range(self.fused_moe_runner.get_tactic_num())) + # Prefer filtering tactics by GEMM stage to avoid invalid combos during tuning + try: + gemm1_count = self.fused_moe_runner.get_gemm1_tactic_count() + gemm2_count = self.fused_moe_runner.get_gemm2_tactic_count() + total = gemm1_count + gemm2_count + except Exception: + return list(range(self.fused_moe_runner.get_tactic_num())) + + stage = getattr(self, "gemm_idx_for_tuning", None) + if stage == 1: + return list(range(gemm1_count)) + if stage == 2: + return list(range(gemm1_count, gemm1_count + gemm2_count)) + return list(range(total)) def forward( self, @@ -480,6 +495,8 @@ def cutlass_fused_moe( activation_type=activation_type, ) + # Limit tactics to GEMM1 during tuning + moe_runner.gemm_idx_for_tuning = 1 _, gemm_tactic_1 = tuner.choose_one( "trtllm::fused_moe::gemm1", [moe_runner], @@ -494,6 +511,8 @@ def cutlass_fused_moe( 
gemm_idx=1, ) + # Limit tactics to GEMM2 during tuning + moe_runner.gemm_idx_for_tuning = 2 _, gemm_tactic_2 = tuner.choose_one( "trtllm::fused_moe::gemm2", [moe_runner], diff --git a/flashinfer/jit/gemm/cutlass/generate_kernels.py b/flashinfer/jit/gemm/cutlass/generate_kernels.py index e767361a65..f7a87bedbd 100644 --- a/flashinfer/jit/gemm/cutlass/generate_kernels.py +++ b/flashinfer/jit/gemm/cutlass/generate_kernels.py @@ -144,6 +144,8 @@ def __init__( epi_schedule, epi_fusion=None, is_mx_fpx=False, + dynamic_cga=False, + swap_ab=False, ): self.gemm_kind = gemm_kind self.arch = arch @@ -158,10 +160,12 @@ def __init__( self.warp_shape = warp_shape self.stages = stages self.cga_shape = cga_shape + self.dynamic_cga = dynamic_cga self.mainloop_schedule = mainloop_schedule self.epi_schedule = epi_schedule self.epi_fusion = epi_fusion self.is_mx_fpx = is_mx_fpx + self.swap_ab = swap_ab def __repr__(self): kernel_prefix = "{}_sm{}_{}_{}_{}_{}_{}_{}_{}_{}x{}x{}_{}x{}x{}_{}".format( @@ -183,13 +187,15 @@ def __repr__(self): self.stages, ) - hopper_suffix = "_{}x{}x{}{}{}{}".format( + hopper_suffix = "_{}x{}x{}{}{}{}{}{}".format( self.cga_shape[0], self.cga_shape[1], self.cga_shape[2], KernelScheduleSuffixes[self.mainloop_schedule], EpilogueScheduleSuffixes[self.epi_schedule], EpiFusionSuffixes[self.epi_fusion], + "_mxfpx_" if self.is_mx_fpx else "", + "_swap_ab" if self.swap_ab else "", ) if self.arch >= 90: @@ -217,7 +223,9 @@ def instantiate_operation_tma_warp_specialized(operation): cute_cga_shape = tuple_to_cute_shape(operation.cga_shape) kernel_sched = KernelScheduleTag[operation.mainloop_schedule] - epi_sched = EpilogueScheduleTag[operation.epi_schedule] + epi_sched = "void" + if operation.epi_schedule is not None: + epi_sched = EpilogueScheduleTag[operation.epi_schedule] if operation.gemm_kind == GemmKind.Gemm: weight_tag = DataTypeTag[operation.weight_type] @@ -228,8 +236,7 @@ def instantiate_operation_tma_warp_specialized(operation): {kernel_sched}, {epi_sched}> ( const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, {out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* -); -""" +);""" elif operation.gemm_kind == GemmKind.Grouped: if operation.act_type != operation.weight_type and ( operation.act_type != DataType.e4m3 or operation.weight_type != e2m1 @@ -247,18 +254,21 @@ def instantiate_operation_tma_warp_specialized(operation): KernelScheduleType.TmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, ] - assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized kernel_sched.replace("::Kernel", "::KernelGrouped") - epi_sched += "Grouped" - + # epi_sched += "Grouped" # arch_tag = f"cutlass::arch::Sm{operation.arch}" arch_tag = f"Sm{operation.arch}" weight_tag = CudaTypeName[operation.weight_type] assert operation.epi_fusion is not None epi_fusion = EpiFusion[operation.epi_fusion] + # We need to remove the '::' because this will break the instantiation macro epi_fusion = epi_fusion.split(":")[-1] epi_tag = epi_tag.split(":")[-1] + epi_sched = epi_sched.split(":")[-1] + epi_sched = epi_sched.replace( + "1Sm", "" + ) # Hack to WAR missing `PtrArrayTmaWarpSpecialized` type guard_map = { e2m1: "defined(ENABLE_FP4)", @@ -267,6 +277,11 @@ def instantiate_operation_tma_warp_specialized(operation): } guard_act = guard_map.get(operation.act_type, "1") guard_weight = guard_map.get(operation.weight_type, 
"1") + + is_mx_fpx = str(operation.is_mx_fpx).lower() + use_dynamic_cga = str(operation.dynamic_cga).lower() + use_bias = str(False).lower() + swap_ab = str(operation.swap_ab).lower() # TODO Revert this once compiler bug is fixed so we can use template instead of macro again # instantiation = f""" # template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag}, @@ -274,11 +289,12 @@ def instantiate_operation_tma_warp_specialized(operation): # (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*); # """ instantiation = f""" -#if {guard_act} && {guard_weight}\n +#if {guard_act} && {guard_weight} INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, - {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);\n -#endif -""" + {epi_sched}, {epi_tag}, {epi_fusion}, + {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, + {is_mx_fpx}, {use_dynamic_cga}, {use_bias}, {swap_ab}); +#endif""" return instantiation @@ -289,8 +305,7 @@ def instantiate_operation_sm80(operation): instantiation = f""" template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}> - ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy); - """ + ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);""" return instantiation @@ -318,12 +333,12 @@ def get_file_content(launcher_inl_files, operations): {{ namespace kernels {{ -namespace cutlass_kernels +namespace cutlass_kernels_oss {{ {instantiations} -}} // namespace cutlass_kernels +}} // namespace cutlass_kernels_oss }} // namespace kernels }} // namespace tensorrt_llm """ @@ -353,17 +368,28 @@ def write_file(launcher_inl_files, operations, output_file): f.write(content) -from operator import mul, truediv - +def is_gemm_op_valid_sm100(op): + # TODO These are much more restricted than theory dictates, investigate if more can be enabled in future + tile_m, tile_n, _ = op.cta_shape + cga_m, cga_n, cga_k = op.cga_shape -def elementwise(x, y, f): - return tuple(f(a, b) for (a, b) in zip(x, y)) + if ( + op.epi_fusion == TrtLlm_EpilogueFusion.epilogue_fusion_finalize + and op.epi_schedule != EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm + ): + return False + # We use a runtime cluster shape for SM100, so we only use cluster shapes to distinguish between 1SM and 2SM variants. 
+ if cga_m > 2 or cga_n != 1 or cga_k != 1: + return False -def is_gemm_op_valid_sm100(op): - # TODO These are much more restricted than theory dictates, investigate if more can be enabled in future - tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv) - cga_m, cga_n, _ = op.cga_shape + if op.arch == 103: + return ( + op.act_type == e2m1 + and op.weight_type == e2m1 + and tile_m == 128 + and tile_n in [128, 256] + ) # Default shapes # This is epilogue tile size. For two CTA this is actually size 128/256 for the MMA @@ -372,23 +398,23 @@ def is_gemm_op_valid_sm100(op): # FP4 Has some much more limited sizes if op.act_type == e2m1 or op.weight_type == e2m1: - # TODO 128x256x256 FP4 compiles but crashes - # if tile_n % 64 != 0 or tile_n < 128: - # return False if tile_n not in [64, 128, 256] or tile_m != 128: return False + # TODO Revert this once cutlass adds support for blockscaled + no smem + if ( + op.arch == 100 + and op.epi_schedule == EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm + ): + return False # Shapes for fp8 small N shapes if ( - op.act_type == DataType.e4m3 + (op.act_type == DataType.e4m3) and (tile_n == 16 or tile_n == 8) and (cga_m == 1 and cga_n == 1) ): - # todo: double check why this is disable in CUTLASS backend. @yuhan - if tile_m == 128 and tile_n == 8: - return False - else: - return True + # todo: double check why tile_n = 8 is disabled in CUTLASS backend. @yuhan + return tile_m != 128 or tile_n % 16 == 0 # Default alignment requirements if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256: @@ -427,7 +453,10 @@ def is_grouped_gemm_op_valid(op): if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default: return False - if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized: + if ( + op.epi_schedule is not None + and op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized + ): return False if op.mainloop_schedule not in [ @@ -543,14 +572,30 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled): TrtLlm_EpilogueFusion.epilogue_fusion_finalize, ] + swap_ab = [True, False] + cga_shapes = product([1, 2], [1, 2], [1]) partial_args = product( - supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes + supported_dtypes, + quant_ops, + epi_tags, + epi_fusions, + cta_shapes_mn, + cga_shapes, + swap_ab, ) operations = list() - for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args: + for ( + dtype, + quant_op, + epi_tag, + epi_fusion, + cta_shape_mn, + cga_shape, + swap_ab, + ) in partial_args: max_k_bits = 128 * 8 cta_shape_k = max_k_bits // GetDataTypeBits(dtype) cta_shape_mnk = cta_shape_mn + (cta_shape_k,) @@ -560,7 +605,7 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled): if dtype != DataType.e4m3 else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum ) - epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized + epi_schedule = None otypes = [dtype] if dtype == DataType.e4m3: @@ -584,6 +629,7 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled): mainloop_schedule, epi_schedule, epi_fusion, + swap_ab=swap_ab, ) if is_op_valid(moe_gemm_operation): @@ -693,8 +739,6 @@ def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype): cta_shape_k = max_k_bits // GetDataTypeBits(dtype) if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8): cta_shape_k = 256 - if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16): - cta_shape_k = 128 return cta_shape_mn + (cta_shape_k,) @@ -702,7 +746,7 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): if not 
is_arch_enabled: return [] arch = 120 - supported_dtypes = [e2m1] + supported_dtypes = [e2m1, (DataType.e4m3, e2m1)] quant_ops = [TrtLlm_QuantOp.none] epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default] cta_shapes_mnk = [ @@ -717,45 +761,71 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize + TrtLlm_EpilogueFusion.epilogue_fusion_finalize, ] cga_shapes = [[1, 1, 1]] + swap_ab = [True, False] + partial_args = product( - supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mnk, cga_shapes + supported_dtypes, + quant_ops, + epi_tags, + epi_fusions, + cta_shapes_mnk, + cga_shapes, + swap_ab, ) operations = list() - for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args: - cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul) - + for ( + dtype, + quant_op, + epi_tag, + epi_fusion, + cta_shape_mnk, + cga_shape, + swap_ab, + ) in partial_args: # Ignored mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative - epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized + epi_schedule = None - otypes = [dtype] - if dtype in [DataType.e4m3, e2m1]: + if isinstance(dtype, tuple): + act_type, weight_type = dtype + else: + act_type, weight_type = dtype, dtype + + # Minimal filter: for mixed FP8xFP4 on SM120, only emit 128x128x128 + if act_type == DataType.e4m3 and weight_type == e2m1: + if cta_shape_mnk != [128, 128, 128]: + continue + + otypes = [act_type] + if act_type in [DataType.e4m3, e2m1]: otypes = [DataType.f16, DataType.bf16] for otype in otypes: moe_gemm_operation = TrtLlm_GemmLauncher( GemmKind.Grouped, arch, - dtype, - dtype, - dtype, - dtype, + act_type, + weight_type, + act_type, + act_type, otype, quant_op, epi_tag, - cga_tile_shape_mnk, + cta_shape_mnk, warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule, epi_fusion, + is_mx_fpx=(act_type == DataType.e4m3 and weight_type == e2m1), + swap_ab=swap_ab, ) operations.append(moe_gemm_operation) @@ -767,10 +837,9 @@ def generate_sm120_operations(is_arch_enabled): return operations -def generate_sm100_grouped_gemm_operations(is_arch_enabled): +def generate_sm100_grouped_gemm_operations(is_arch_enabled, arch): if not is_arch_enabled: return [] - arch = 100 supported_dtypes = [ DataType.f16, DataType.bf16, @@ -782,7 +851,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): quant_ops = [TrtLlm_QuantOp.none] epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default] cta_shapes_m = [64, 128] - cta_shapes_n = [8, 16, 32, 64, 128, 256] + cta_shapes_n = [8, 16, 32, 64, 128, 192, 256] cta_shapes_mn = product(cta_shapes_m, cta_shapes_n) warp_shape = [0, 0, 0] # ignored except for naming @@ -790,28 +859,55 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize + TrtLlm_EpilogueFusion.epilogue_fusion_finalize, ] - cga_shapes = list(product([1, 2], [1, 2], [1])) + # Some shapes for SM100 are better with NoSmem, note the kernel will internally map to the 1 or 2 SM variants based on the cga_shape[0] + epi_schedules = [ + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + ] + + # We will use dynamic cluster shapes for SM100, so we only need to indicate if we are using 1 or 2 SM version + cga_shapes = [(1, 1, 1), (2, 1, 1)] + + swap_ab = [True, False] + + dynamic_cga = [True, False] 
partial_args = product( - supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes + supported_dtypes, + quant_ops, + epi_tags, + epi_fusions, + cta_shapes_mn, + cga_shapes, + epi_schedules, + dynamic_cga, + swap_ab, ) operations = list() - for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args: + for ( + dtype, + quant_op, + epi_tag, + epi_fusion, + cta_shape_mn, + cga_shape, + epi_schedule, + dynamic_cga, + swap_ab, + ) in partial_args: if isinstance(dtype, tuple): dtype, weight_type = dtype else: weight_type = dtype cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype) - cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul) # Ignored mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative - epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized otypes = [dtype] if dtype in [DataType.e4m3, e2m1]: @@ -828,7 +924,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): otype, quant_op, epi_tag, - cga_tile_shape_mnk, + cta_shape_mnk, warp_shape, stages, cga_shape, @@ -836,6 +932,8 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): epi_schedule, epi_fusion, is_mx_fpx=(dtype == DataType.e4m3 and weight_type == e2m1), + dynamic_cga=dynamic_cga, + swap_ab=swap_ab, ) if is_op_valid(moe_gemm_operation): @@ -843,8 +941,13 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): return operations +def generate_sm103_operations(is_arch_enabled): + operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 103) + return operations + + def generate_sm100_operations(is_arch_enabled): - operations = generate_sm100_grouped_gemm_operations(is_arch_enabled) + operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 100) return operations @@ -908,18 +1011,25 @@ def generate_gemm_operations(output_dir, architectures): (GemmKind.Gemm, 90): [fpA_intB_inl], (GemmKind.Grouped, 90): [moe_gemm_inl], (GemmKind.Grouped, 100): [moe_gemm_inl], + (GemmKind.Grouped, 103): [moe_gemm_inl], (GemmKind.Grouped, 120): [moe_gemm_inl], (GemmKind.Grouped, 80): [sm80_moe_gemm_inl], } def has_arch(sm): - return f"{sm}" in arches or f"{sm}-real" in arches + return ( + f"{sm}" in arches + or f"{sm}-real" in arches + or f"{sm}f-real" in arches + or f"{sm}f" in arches + ) # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads. # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve. 
operations = [] operations += generate_sm120_operations(has_arch(120) or has_arch(121)) - operations += generate_sm100_operations(has_arch(100)) + operations += generate_sm103_operations(has_arch(103)) + operations += generate_sm100_operations(has_arch(100) or has_arch(103)) operations += generate_sm90_operations(has_arch(90)) operations += generate_sm80_operations(has_arch(80) or has_arch(89)) diff --git a/tests/moe/test_trtllm_cutlass_fused_moe.py b/tests/moe/test_trtllm_cutlass_fused_moe.py index bae12ab070..acca53f9a0 100644 --- a/tests/moe/test_trtllm_cutlass_fused_moe.py +++ b/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -1109,8 +1109,8 @@ def dequant_mxfp4_batches( ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] ) @pytest.mark.skipif( - torch.cuda.get_device_capability()[0] not in [10, 11], - reason="MXFP8xMXFP4 is only supported on SM100 and SM110", + torch.cuda.get_device_capability()[0] not in [10, 11, 12], + reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120", ) def test_moe_mxfp8_mxfp4( batch_size, From f588d96af1d27bd8714b63e40ccc2955d92a5593 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Fri, 7 Nov 2025 11:50:59 -0800 Subject: [PATCH 038/130] perf: Optimize helper max/minmax function in sampling.cuh (#2058) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Apply optimizations similar to #2044 to max/min functions. ## ๐Ÿ” Related Issues #2044 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Performance Improvements** * Improved sampling performance by reducing per-iteration synchronization and temporary storage, deferring aggregate reductions until after iterative work completes. This lowers runtime overhead and memory churn, yielding faster and more efficient processing for sampling operations. 
Co-authored-by: yzh119 --- include/flashinfer/sampling.cuh | 43 ++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index f3b188abec..03d4bfa8e2 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -249,27 +249,31 @@ __device__ __forceinline__ std::tuple GetMinMaxValue(float* in_dat TempStorage& temp_storage) { const uint32_t tx = threadIdx.x; vec_t in_data_vec; - float max_val = -cuda::std::numeric_limits::infinity(), - min_val = cuda::std::numeric_limits::infinity(); + // Thread-local min/max accumulation (deferred reduction) + float thread_max = -cuda::std::numeric_limits::infinity(); + float thread_min = cuda::std::numeric_limits::infinity(); + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { in_data_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); } - float in_data_[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - in_data_[j] = in_data_vec[j]; + thread_max = max(thread_max, static_cast(in_data_vec[j])); + thread_min = min(thread_min, static_cast(in_data_vec[j])); } - max_val = max( - max_val, BlockReduce(temp_storage.block_prim.reduce) - .Reduce(in_data_, MaxReduceOp{})); - __syncthreads(); - min_val = min( - min_val, BlockReduce(temp_storage.block_prim.reduce) - .Reduce(in_data_, MinReduceOp{})); - __syncthreads(); } + + // Single block reduction after loop completes + float max_val = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(thread_max, MaxReduceOp{}); + __syncthreads(); + float min_val = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(thread_min, MinReduceOp{}); + if (tx == 0) { temp_storage.max_val = max_val; temp_storage.min_val = min_val; @@ -288,22 +292,23 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u const uint32_t tx = threadIdx.x; vec_t in_data_vec; - float max_val = 0; + // Thread-local max accumulation (deferred reduction) + float thread_max = 0.0f; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { in_data_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); } - float in_data_[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - in_data_[j] = in_data_vec[j]; + thread_max = max(thread_max, static_cast(in_data_vec[j])); } - max_val = max( - max_val, BlockReduce(temp_storage.block_prim.reduce) - .template Reduce(in_data_, MaxReduceOp{})); - __syncthreads(); } + + // Single block reduction after loop completes + float max_val = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(thread_max, MaxReduceOp{}); if (tx == 0) { temp_storage.max_val = max_val; } From c8f2b03dda635b97f7ea9ad9d6ec73eb96a153b2 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe <50598321+nvmbreughe@users.noreply.github.com> Date: Fri, 7 Nov 2025 16:52:59 -0600 Subject: [PATCH 039/130] [DSV3] Optimized Router Gemm (#2019) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This PR: * adds an optimized router gemm for problem sizes such as Deep Seek-V3. It is ported over from TRTLLM. * serves as an example on API naming for specialized ops on narrow support surfaces From my measurements (num tokens = [1,2,4,8,16]), speedups were observed between 1.36 and 1.82x on B200. 
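For illustration, a minimal usage sketch of the new op. This is hedged: the Python wrapper's exact signature is not shown in this description, so the `mm_M1_16_K7168_N256(a, b)` call, its return convention, and the tolerances below are assumptions; the shape/dtype/layout constraints mirror the checks in `dsv3_router_gemm.cu` (bf16 inputs, fp32 output, 1-16 tokens, hidden dim 7168, 256 experts, row-major activations, column-major weights).

```python
# Hypothetical sketch only: the wrapper's exact signature (return value vs. out
# argument, optional PDL flag) is an assumption; the constraints follow the
# validation in csrc/dsv3_router_gemm.cu.
import torch
from flashinfer.dsv3_ops import mm_M1_16_K7168_N256

num_tokens, hidden_dim, num_experts = 4, 7168, 256
# Activations: row-major bf16 [num_tokens, hidden_dim]
a = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16, device="cuda")
# Router weights: column-major bf16 [hidden_dim, num_experts] (stride[0] == 1)
b = torch.randn(num_experts, hidden_dim, dtype=torch.bfloat16, device="cuda").t()

out = mm_M1_16_K7168_N256(a, b)  # assumed to return fp32 logits [num_tokens, num_experts]
ref = a.float() @ b.float()      # plain PyTorch reference in fp32
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```

The long, shape-specific name is deliberate: it makes the narrow support surface explicit, which is the API-naming pattern for specialized ops that this PR is meant to demonstrate.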
Both positive and negative tests were added to test the behavior. ## Breaking Change: Refactored gemm module structure **ACTION REQUIRED:** Delete stale `flashinfer/gemm.py` file The `gemm.py` file has been refactored into a package: - `flashinfer/gemm.py` โ†’ `flashinfer/gemm/gemm_base.py` After pulling this change, run: ```bash git clean -fd flashinfer/ # OR manually: rm flashinfer/flashinfer/gemm.py ``` This is backward compatible - no import changes needed. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * High-performance DSv3 router GEMM (bf16 โ†’ float32) optimized for tokens 1โ€“16, 256 experts, 7168 hidden dim with optional serialized launch. * **Integration** * Python wrapper exposing the op with runtime shape/dtype/stride validation and registered custom-op entrypoint. * **JIT / Packaging** * Adds a JIT module generator and re-exports it for easy import. * **Tests** * Unit tests for supported configs and comprehensive validation/error cases. * **Chores** * Import-path cleanup and test-script pre-run bytecode cache cleanup. --------- Co-authored-by: yzh119 --- csrc/dsv3_router_gemm.cu | 152 +++++++++++++++++ flashinfer/dsv3_ops/__init__.py | 5 + flashinfer/gemm/__init__.py | 34 ++++ flashinfer/{gemm.py => gemm/gemm_base.py} | 36 ++-- flashinfer/gemm/routergemm_dsv3.py | 134 +++++++++++++++ flashinfer/jit/__init__.py | 3 + flashinfer/jit/dsv3_optimizations.py | 11 ++ include/flashinfer/gemm/dsv3_router_gemm.cuh | 159 ++++++++++++++++++ scripts/task_test_blackwell_kernels.sh | 7 + ...ask_test_jit_cache_package_build_import.sh | 6 + scripts/task_test_multi_node_comm_kernels.sh | 7 + scripts/task_test_nightly_build.sh | 7 + scripts/task_test_single_node_comm_kernels.sh | 7 + tests/gemm/test_group_gemm.py | 5 +- tests/gemm/test_mm_fp4.py | 2 +- tests/gemm/test_tgv_gemm.py | 2 +- .../test_dsv3_router_gemm.py | 137 +++++++++++++++ 17 files changed, 692 insertions(+), 22 deletions(-) create mode 100644 csrc/dsv3_router_gemm.cu create mode 100644 flashinfer/dsv3_ops/__init__.py create mode 100644 flashinfer/gemm/__init__.py rename flashinfer/{gemm.py => gemm/gemm_base.py} (99%) create mode 100644 flashinfer/gemm/routergemm_dsv3.py create mode 100644 flashinfer/jit/dsv3_optimizations.py create mode 100644 include/flashinfer/gemm/dsv3_router_gemm.cuh create mode 100644 tests/model_optimizations/test_dsv3_router_gemm.py diff --git a/csrc/dsv3_router_gemm.cu b/csrc/dsv3_router_gemm.cu new file mode 100644 index 0000000000..2d44147d97 --- /dev/null +++ b/csrc/dsv3_router_gemm.cu @@ -0,0 +1,152 @@ +#include "flashinfer/gemm/dsv3_router_gemm.cuh" +#include "tvm_ffi_utils.h" + +namespace flashinfer::trtllm_dsv3_router_gemm { +template +void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream, + bool use_pdl = false) { + constexpr int VPT = 16 / sizeof(T); + 
constexpr int kBlockSize = 128; + cudaLaunchConfig_t config; + config.gridDim = kNumExperts; + config.blockDim = kBlockSize; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = use_pdl; + config.numAttrs = 1; + config.attrs = attrs; + auto status = cudaLaunchKernelEx( + &config, router_gemm_kernel, output, + mat_a, mat_b); + TVM_FFI_ICHECK(status == cudaSuccess) + << "cudaLaunchKernelEx failed with error code " << cudaGetErrorString(status); +} + +template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, + __nv_bfloat16 const*, cudaStream_t, + bool); + +template +struct LoopUnroller { + static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) { + if (num_tokens == kBegin) { + invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, + stream, launch_with_pdl); + } else { + LoopUnroller::unroll( + num_tokens, output, input, weights, stream, launch_with_pdl); + } + } +}; + +template +struct LoopUnroller { + static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) { + if (num_tokens == kEnd) { + 
invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream, + launch_with_pdl); + } else { + throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); + } + } +}; + +void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) { + int const num_tokens = mat_a.sizes()[0]; + int const num_experts = mat_b.sizes()[1]; + int const hidden_dim = mat_a.sizes()[1]; + auto const out_dtype_ = out.dtype(); + auto const data_type = mat_a.dtype(); + constexpr int kNumExperts = 256; + constexpr int kHiddenDim = 7168; + std::vector output_size = {mat_a.sizes()[0], mat_b.sizes()[1]}; + TVM_FFI_ICHECK(mat_a.dim() == 2 && mat_b.dim() == 2) << "mat_a and mat_b must be 2D tensors"; + TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1) + << "mat_a and out must be row-major"; + TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major"; + auto stream = get_stream(mat_a.device()); + bool use_custom_kernel = false; + if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts && + hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code && + encode_dlpack_dtype(out_dtype_) == float32_code) { + use_custom_kernel = true; + } + + if (use_custom_kernel) { + LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll( + num_tokens, reinterpret_cast(out.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl); + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input tensor size"; + } +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(dsv3_router_gemm_op, + flashinfer::trtllm_dsv3_router_gemm::dsv3_router_gemm_op); + +} // namespace flashinfer::trtllm_dsv3_router_gemm diff --git a/flashinfer/dsv3_ops/__init__.py b/flashinfer/dsv3_ops/__init__.py new file mode 100644 index 0000000000..49fb43b3ec --- /dev/null +++ b/flashinfer/dsv3_ops/__init__.py @@ -0,0 +1,5 @@ +from flashinfer.gemm import mm_M1_16_K7168_N256 + +__all__ = [ + "mm_M1_16_K7168_N256", +] diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py new file mode 100644 index 0000000000..15652268ba --- /dev/null +++ b/flashinfer/gemm/__init__.py @@ -0,0 +1,34 @@ +from .gemm_base import SegmentGEMMWrapper as SegmentGEMMWrapper +from .gemm_base import bmm_fp8 as bmm_fp8 +from .gemm_base import mm_fp4 as mm_fp4 +from .gemm_base import mm_fp8 as mm_fp8 +from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100 +from .gemm_base import group_gemm_mxfp4_nt_groupwise as group_gemm_mxfp4_nt_groupwise +from .gemm_base import ( + batch_deepgemm_fp8_nt_groupwise as batch_deepgemm_fp8_nt_groupwise, +) +from .gemm_base import ( + group_deepgemm_fp8_nt_groupwise as group_deepgemm_fp8_nt_groupwise, +) +from .gemm_base import gemm_fp8_nt_blockscaled as gemm_fp8_nt_blockscaled +from .gemm_base import gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise +from .gemm_base import group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise + +from .routergemm_dsv3 import ( + mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256, +) + +__all__ = [ + "SegmentGEMMWrapper", + "bmm_fp8", + "mm_fp4", + "mm_fp8", + "tgv_gemm_sm100", + "group_gemm_mxfp4_nt_groupwise", + "batch_deepgemm_fp8_nt_groupwise", + "group_deepgemm_fp8_nt_groupwise", + "gemm_fp8_nt_blockscaled", + "gemm_fp8_nt_groupwise", + "group_gemm_fp8_nt_groupwise", + "mm_M1_16_K7168_N256", +] diff --git a/flashinfer/gemm.py b/flashinfer/gemm/gemm_base.py similarity index 99% rename from 
flashinfer/gemm.py rename to flashinfer/gemm/gemm_base.py index fc4c1b8885..ac0fbab4a0 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm/gemm_base.py @@ -22,7 +22,7 @@ from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch -from .autotuner import ( +from ..autotuner import ( AutoTuner, ConstraintSpec, DynamicTensorSpec, @@ -30,11 +30,11 @@ TunableRunner, TuningConfig, ) -from .fused_moe.utils import ( +from ..fused_moe.utils import ( get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, ) -from .utils import ( +from ..utils import ( get_native_fp4_dtype, is_sm100a_supported, is_sm100f_supported, @@ -44,16 +44,16 @@ backend_requirement, supported_compute_capability, ) -from .jit.gemm import gen_gemm_sm90_module -from .jit.gemm import gen_gemm_module -from .jit.gemm import gen_gemm_sm100_module -from .jit.gemm import gen_gemm_sm120_module -from .jit.gemm import gen_gemm_sm120_module_cutlass_fp4 -from .jit.gemm import gen_gemm_sm100_module_cutlass_fp4 -from .jit.gemm import gen_gemm_sm100_module_cutlass_fp8 -from .jit.gemm import gen_trtllm_gen_gemm_module -from .jit.gemm import gen_tgv_gemm_sm10x_module -from .jit.gemm import gen_deepgemm_sm100_module +from ..jit.gemm import gen_gemm_sm90_module +from ..jit.gemm import gen_gemm_module +from ..jit.gemm import gen_gemm_sm100_module +from ..jit.gemm import gen_gemm_sm120_module +from ..jit.gemm import gen_gemm_sm120_module_cutlass_fp4 +from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp4 +from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8 +from ..jit.gemm import gen_trtllm_gen_gemm_module +from ..jit.gemm import gen_tgv_gemm_sm10x_module +from ..jit.gemm import gen_deepgemm_sm100_module CUDNN_AVAILABLE = False @@ -70,8 +70,8 @@ raise -from .jit.cubin_loader import setup_cubin_loader -from .utils import ( +from ..jit.cubin_loader import setup_cubin_loader +from ..utils import ( _get_cache_buf, determine_gemm_backend, get_indptr, @@ -733,7 +733,7 @@ def launch_compute_sm80_group_gemm_args( w_stride_data = torch.empty(batch_size, dtype=ld_type, device=device) y_stride_data = torch.empty(batch_size, dtype=ld_type, device=device) - from .triton.gemm import compute_sm80_group_gemm_args + from ..triton.gemm import compute_sm80_group_gemm_args compute_sm80_group_gemm_args[(batch_size,)]( all_problems, @@ -795,7 +795,7 @@ def launch_compute_sm90_group_gemm_args( w_stride_data = torch.empty(batch_size, dtype=stride_type, device=device) y_stride_data = torch.empty(batch_size, dtype=stride_type, device=device) - from .triton.gemm import compute_sm90_group_gemm_args + from ..triton.gemm import compute_sm90_group_gemm_args compute_sm90_group_gemm_args[(batch_size,)]( all_problems, @@ -2822,7 +2822,7 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( def pad_indptr_to_multiple_of_4( m_indptr: torch.Tensor, ): - from .triton.gemm import compute_padding_mapping + from ..triton.gemm import compute_padding_mapping batch_size = m_indptr.shape[0] - 1 m = m_indptr[1:] - m_indptr[:-1] diff --git a/flashinfer/gemm/routergemm_dsv3.py b/flashinfer/gemm/routergemm_dsv3.py new file mode 100644 index 0000000000..05415ec61f --- /dev/null +++ b/flashinfer/gemm/routergemm_dsv3.py @@ -0,0 +1,134 @@ +from flashinfer.jit import gen_dsv3_router_gemm_module +import functools +from types import SimpleNamespace +import torch +from flashinfer.utils import ( + register_custom_op, + supported_compute_capability, + backend_requirement, +) + + +# TODO: other compute capabilities may be supported but are untested 
+@supported_compute_capability([100]) +def _mm_M1_16_K7168_N256_shape_checks(mat_a, mat_b, out, launch_with_pdl): + # Dimension checks + if mat_a.dim() != 2: + raise ValueError("mat_a must be a 2D tensor") + if mat_b.dim() != 2: + raise ValueError("mat_b must be a 2D tensor") + if out.dim() != 2: + raise ValueError("out must be a 2D tensor") + + # Stride checks (check these before dimension checks to give better error messages) + if mat_a.stride(1) != 1: + raise ValueError("mat_a must be row-major") + if out.stride(1) != 1: + raise ValueError("out must be row-major") + if mat_b.stride(0) != 1: + raise ValueError("mat_b must be column-major") + + if mat_a.shape[1] != mat_b.shape[0]: + raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]") + if out.shape[0] != mat_a.shape[0]: + raise ValueError("out.shape[0] must be equal to mat_a.shape[0]") + if out.shape[1] != mat_b.shape[1]: + raise ValueError("out.shape[1] must be equal to mat_b.shape[1]") + + # Problem size checks + expected_hidden_dim = 7168 + expected_num_experts = 256 + min_tokens = 1 + max_tokens = 16 + if mat_a.shape[0] < min_tokens or mat_a.shape[0] > max_tokens: + raise ValueError( + f"mat_a.shape[0] (num_tokens) must be between {min_tokens} and {max_tokens}" + ) + if mat_a.shape[1] != expected_hidden_dim: + raise ValueError( + f"mat_a.shape[1] (hidden_dim) must be equal to {expected_hidden_dim}" + ) + if mat_b.shape[1] != expected_num_experts: + raise ValueError( + f"mat_b.shape[1] (num_experts) must be equal to {expected_num_experts}" + ) + + # Data type checks + if mat_a.dtype != torch.bfloat16: + raise ValueError("mat_a must be a bfloat16 tensor") + if mat_b.dtype != torch.bfloat16: + raise ValueError("mat_b must be a bfloat16 tensor") + if out.dtype != torch.float32: + raise ValueError("out must be a float32 tensor") + + return True + + +@functools.cache +def get_dsv3_router_gemm_module(): + module = gen_dsv3_router_gemm_module().build_and_load() + + @register_custom_op( + "flashinfer::dsv3_router_gemm_op", + mutates_args=["out"], + ) + def mm_M1_16_K7168_N256( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + out: torch.Tensor, + launch_with_pdl: bool = False, + ) -> None: + module.dsv3_router_gemm_op(mat_a, mat_b, out, launch_with_pdl) + + return SimpleNamespace( + mm_M1_16_K7168_N256=mm_M1_16_K7168_N256, + ) + + +@backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks) +def mm_M1_16_K7168_N256( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + out: torch.Tensor, + launch_with_pdl: bool = False, +) -> None: + """Optimized GEMM for the router operation in DeepSeek-V3. + + This function performs a highly optimized matrix multiplication specifically tailored + for the expert routing GEMM in DeepSeek-V3's Mixture of Experts (MoE) architecture. + It computes out = mat_a @ mat_b where mat_a contains token embeddings and mat_b + contains expert routing weights. + + The implementation is optimized for the specific problem dimensions used in DeepSeek-V3: + - Hidden dimension (K): 7168 + - Number of experts (N): 256 + - Number of tokens (M): 1-16 + + Args: + mat_a (torch.Tensor): Input token embeddings of shape (M, K) where M is the number + of tokens (1-16) and K is the hidden dimension (7168). Must be bfloat16, + row-major (contiguous). + mat_b (torch.Tensor): Expert routing weights of shape (K, N) where K is the hidden + dimension (7168) and N is the number of experts (256). Must be bfloat16, + column-major (transposed layout). 
+ out (torch.Tensor): Pre-allocated output tensor of shape (M, N) containing the + routing scores. Must be float32, row-major (contiguous). This tensor is + mutated in-place. + launch_with_pdl (bool, optional): Whether to launch the kernel using Persistent + Device-side Launch. Defaults to False. + + Returns: + None: The result is written directly to the `out` tensor. + + Raises: + ValueError: If tensor dimensions, strides, or data types do not match the + expected DeepSeek-V3 router configuration. + + Note: + This kernel is specialized for compute capability 10.0 (Blackwell architecture). + The specific problem size optimization makes this significantly faster than + general-purpose GEMM implementations for the router operation. + """ + get_dsv3_router_gemm_module().mm_M1_16_K7168_N256( + mat_a, mat_b, out, launch_with_pdl + ) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 314dee1eb3..bc4132ec9c 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -76,6 +76,9 @@ from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module from .comm import gen_vllm_comm_module as gen_vllm_comm_module from .comm import gen_nvshmem_module as gen_nvshmem_module +from .dsv3_optimizations import ( + gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module, +) cuda_lib_path = os.environ.get( diff --git a/flashinfer/jit/dsv3_optimizations.py b/flashinfer/jit/dsv3_optimizations.py new file mode 100644 index 0000000000..88be890699 --- /dev/null +++ b/flashinfer/jit/dsv3_optimizations.py @@ -0,0 +1,11 @@ +from .core import JitSpec, gen_jit_spec +from . import env as jit_env + + +def gen_dsv3_router_gemm_module() -> JitSpec: + return gen_jit_spec( + "dsv3_router_gemm", + [ + jit_env.FLASHINFER_CSRC_DIR / "dsv3_router_gemm.cu", + ], + ) diff --git a/include/flashinfer/gemm/dsv3_router_gemm.cuh b/include/flashinfer/gemm/dsv3_router_gemm.cuh new file mode 100644 index 0000000000..aef712d68e --- /dev/null +++ b/include/flashinfer/gemm/dsv3_router_gemm.cuh @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +namespace flashinfer::trtllm_dsv3_router_gemm { +// Custom FMA implementation using PTX assembly instructions +__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) { + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +} + +// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion +template +__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) { + __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast(&vec)); + +#pragma unroll + for (int i = 0; i < VPT; i++) { + dst[i] = __bfloat162float(bf16_ptr[i]); + } +} + +template +__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, + T const* mat_b) { + // Each block handles one expert column + int const n_idx = blockIdx.x; + int const tid = threadIdx.x; + constexpr int kWarpSize = 32; + constexpr int kNumWarps = kBlockSize / kWarpSize; + // Constants for this kernel + constexpr int k_elems_per_k_iteration = VPT * kBlockSize; + constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration; // Total K iterations + + // Initialize accumulators for all M rows + float acc[kNumTokens] = {}; + + // Shared memory for warp-level reduction + __shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps + + // B matrix is in column-major order, so we can directly load a column for the n_idx expert + T const* b_col = mat_b + n_idx * kHiddenDim; + + // Pre-compute k_base values for each iteration to help compiler optimize + // int k_bases[k_iterations]; + int k_bases[k_iterations]; +#pragma unroll + for (int ki = 0; ki < k_iterations; ki++) { + k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // Process the GEMM in chunks + for (int ki = 0; ki < k_iterations; ki++) { + int const k_base = k_bases[ki]; + + // Load B matrix values using vector load (8 bf16 values) + uint4 b_vec = *reinterpret_cast(b_col + k_base); + + // Convert B values to float + float b_float[VPT]; + bf16_uint4_to_float8(b_vec, b_float); + +// Process each token +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + // Load both rows of A matrix using vector loads + uint4 a_vec = *reinterpret_cast(mat_a + (m_idx * kHiddenDim) + k_base); + + // Convert A values to float + float a_float[VPT]; + bf16_uint4_to_float8(a_vec, a_float); + +// Process elements in this chunk +#pragma unroll + for (int k = 0; k < VPT; k++) { + float a = a_float[k]; + float b = b_float[k]; + acc[m_idx] += a * b; + } + } + } + + // Perform warp-level reduction + int const warpSize = 32; + int const warpId = tid / warpSize; + int const laneId = tid % warpSize; + + // Register for warp-level reduction results + float warp_result[kNumTokens]; + +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + warp_result[m_idx] = acc[m_idx]; + } + +// Perform warp-level reduction using optimized butterfly pattern +#pragma unroll + for (int m = 0; m < kNumTokens; m++) { + float sum = warp_result[m]; + + // Butterfly reduction pattern + sum += __shfl_xor_sync(0xffffffff, sum, 16); + sum += __shfl_xor_sync(0xffffffff, sum, 8); + sum += __shfl_xor_sync(0xffffffff, sum, 4); + sum += __shfl_xor_sync(0xffffffff, sum, 2); + sum += __shfl_xor_sync(0xffffffff, sum, 1); + + // Only the first 
thread in each warp stores to shared memory + if (laneId == 0) { + sm_reduction[m][warpId] = sum; + } + } + + __syncthreads(); + + // Final reduction across warps (only first thread) + if (tid == 0) { +#pragma unroll + for (int m = 0; m < kNumTokens; m++) { + float final_sum = 0.0f; + +// Sum across the kNumWarps +#pragma unroll + for (int w = 0; w < kNumWarps; w++) { + final_sum += sm_reduction[m][w]; + } + + // Write final result + out[m * kNumExperts + n_idx] = final_sum; + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} +} // namespace flashinfer::trtllm_dsv3_router_gemm diff --git a/scripts/task_test_blackwell_kernels.sh b/scripts/task_test_blackwell_kernels.sh index fb35e168af..312cf12eb1 100644 --- a/scripts/task_test_blackwell_kernels.sh +++ b/scripts/task_test_blackwell_kernels.sh @@ -6,6 +6,13 @@ set -eo pipefail : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +# Clean Python bytecode cache to avoid stale imports (e.g., after module refactoring) +echo "Cleaning Python bytecode cache..." +find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true +find . -type f -name '*.pyc' -delete 2>/dev/null || true +echo "Cache cleaned." +echo "" + # Pytest configuration flags PYTEST_FLAGS="--continue-on-collection-errors -s" diff --git a/scripts/task_test_jit_cache_package_build_import.sh b/scripts/task_test_jit_cache_package_build_import.sh index c8e4cfc6b6..e2e4a824aa 100755 --- a/scripts/task_test_jit_cache_package_build_import.sh +++ b/scripts/task_test_jit_cache_package_build_import.sh @@ -28,6 +28,12 @@ export MAX_JOBS : ${CUDA_VISIBLE_DEVICES:=""} echo "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}" +# Clean Python bytecode cache to avoid stale imports (e.g., after module refactoring) +echo "Cleaning Python bytecode cache..." +find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true +find . -type f -name '*.pyc' -delete 2>/dev/null || true +echo "Cache cleaned." + echo "" echo "Detecting CUDA architecture list..." export FLASHINFER_CUDA_ARCH_LIST=$(python3 -c ' diff --git a/scripts/task_test_multi_node_comm_kernels.sh b/scripts/task_test_multi_node_comm_kernels.sh index 7ece7fe0ad..f1dcedc93b 100644 --- a/scripts/task_test_multi_node_comm_kernels.sh +++ b/scripts/task_test_multi_node_comm_kernels.sh @@ -5,6 +5,13 @@ set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +# Clean Python bytecode cache to avoid stale imports (e.g., after module refactoring) +echo "Cleaning Python bytecode cache..." +find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true +find . -type f -name '*.pyc' -delete 2>/dev/null || true +echo "Cache cleaned." +echo "" + pip install -e . -v pytest -s tests/comm/test_mnnvl_memory.py diff --git a/scripts/task_test_nightly_build.sh b/scripts/task_test_nightly_build.sh index 46f6b76d36..ad7773b5ab 100755 --- a/scripts/task_test_nightly_build.sh +++ b/scripts/task_test_nightly_build.sh @@ -12,6 +12,13 @@ set -x : ${DIST_JIT_CACHE_DIR:=dist-jit-cache} : ${DIST_PYTHON_DIR:=dist-python} +# Clean Python bytecode cache to avoid stale imports (e.g., after module refactoring) +echo "Cleaning Python bytecode cache..." +find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true +find . -type f -name '*.pyc' -delete 2>/dev/null || true +echo "Cache cleaned." 
+echo "" + # Display GPU information (running inside Docker container with GPU access) echo "=== GPU Information ===" nvidia-smi diff --git a/scripts/task_test_single_node_comm_kernels.sh b/scripts/task_test_single_node_comm_kernels.sh index 4d9c4ff3f3..19593258db 100644 --- a/scripts/task_test_single_node_comm_kernels.sh +++ b/scripts/task_test_single_node_comm_kernels.sh @@ -5,6 +5,13 @@ set -x : ${MAX_JOBS:=$(nproc)} : ${CUDA_VISIBLE_DEVICES:=0} +# Clean Python bytecode cache to avoid stale imports (e.g., after module refactoring) +echo "Cleaning Python bytecode cache..." +find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true +find . -type f -name '*.pyc' -delete 2>/dev/null || true +echo "Cache cleaned." +echo "" + pip install -e . -v # vllm ar diff --git a/tests/gemm/test_group_gemm.py b/tests/gemm/test_group_gemm.py index fbdd9e26e4..739527f726 100644 --- a/tests/gemm/test_group_gemm.py +++ b/tests/gemm/test_group_gemm.py @@ -23,6 +23,7 @@ has_flashinfer_jit_cache, is_sm90a_supported, ) +from flashinfer.jit.gemm import gen_gemm_module, gen_gemm_sm90_module DTYPES = [torch.float16] CUDA_DEVICES = ["cuda:0"] @@ -33,9 +34,9 @@ scope="module", ) def warmup_jit(): - jit_specs = [flashinfer.gemm.gen_gemm_module()] + jit_specs = [gen_gemm_module()] if is_sm90a_supported(torch.device("cuda:0")): - jit_specs.append(flashinfer.gemm.gen_gemm_sm90_module()) + jit_specs.append(gen_gemm_sm90_module()) flashinfer.jit.build_jit_specs(jit_specs, verbose=False) yield diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index 4c90bf3fe9..9d7a7abbbd 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -9,7 +9,7 @@ mxfp4_quantize, ) from flashinfer.utils import get_compute_capability, LibraryError -from flashinfer.gemm import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR +from flashinfer.gemm.gemm_base import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR # TODO: Consdier splitting this function up for the various backends diff --git a/tests/gemm/test_tgv_gemm.py b/tests/gemm/test_tgv_gemm.py index ee7fc67926..0296cbbb54 100755 --- a/tests/gemm/test_tgv_gemm.py +++ b/tests/gemm/test_tgv_gemm.py @@ -6,7 +6,7 @@ tgv_gemm_sm100, ) -from flashinfer.gemm import _match_sm_version +from flashinfer.gemm.gemm_base import _match_sm_version @pytest.mark.parametrize("m", [1, 8, 16, 32, 64]) diff --git a/tests/model_optimizations/test_dsv3_router_gemm.py b/tests/model_optimizations/test_dsv3_router_gemm.py new file mode 100644 index 0000000000..c4c8f1ce7b --- /dev/null +++ b/tests/model_optimizations/test_dsv3_router_gemm.py @@ -0,0 +1,137 @@ +import torch +import pytest +from flashinfer.dsv3_ops import mm_M1_16_K7168_N256 +import torch.nn.functional as F +from flashinfer.utils import get_compute_capability + + +# Positive tests +@pytest.mark.parametrize("num_tokens", [1, 2, 3, 5, 8, 13, 16]) +@pytest.mark.parametrize("num_experts", [256]) +@pytest.mark.parametrize("hidden_dim", [7168]) +@pytest.mark.parametrize("launch_with_pdl", [True, False]) +def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pdl): + compute_capability = get_compute_capability(torch.device("cuda")) + compute_capability_number = compute_capability[0] * 10 + compute_capability[1] + if compute_capability_number != 100: + pytest.skip("DSv3 Router GEMM is only supported on SM100") + + mat_a = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + mat_b = torch.randn( + num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16 + ).t() # column major + out = 
torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32) + mm_M1_16_K7168_N256(mat_a, mat_b, out, launch_with_pdl=launch_with_pdl) + ref = mat_a @ mat_b + + cos_sim = F.cosine_similarity(ref.reshape(-1), out.reshape(-1), dim=0) + assert cos_sim > 0.99 + + +# Negative tests - test values just outside valid ranges +@pytest.mark.parametrize( + "num_tokens,num_experts,hidden_dim,mat_a_dtype,mat_b_dtype,out_dtype,mat_b_transpose,expected_error", + [ + # Invalid num_tokens (must be 1-16) + ( + 0, + 256, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "num_tokens", + ), + ( + 17, + 256, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "num_tokens", + ), + # Invalid num_experts (must be 256) + ( + 8, + 255, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "num_experts", + ), + ( + 8, + 257, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "num_experts", + ), + # Invalid hidden_dim (must be 7168) + ( + 8, + 256, + 7167, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "hidden_dim", + ), + ( + 8, + 256, + 7169, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "hidden_dim", + ), + # Invalid dtypes + (8, 256, 7168, torch.float32, torch.bfloat16, torch.float32, True, "bfloat16"), + (8, 256, 7168, torch.bfloat16, torch.float32, torch.float32, True, "bfloat16"), + (8, 256, 7168, torch.bfloat16, torch.bfloat16, torch.bfloat16, True, "float32"), + # Invalid stride (mat_b not transposed = row-major instead of column-major) + ( + 8, + 256, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.float32, + False, + "column-major", + ), + ], +) +def test_dsv3_router_gemm_op_negative( + num_tokens, + num_experts, + hidden_dim, + mat_a_dtype, + mat_b_dtype, + out_dtype, + mat_b_transpose, + expected_error, +): + compute_capability = get_compute_capability(torch.device("cuda")) + compute_capability_number = compute_capability[0] * 10 + compute_capability[1] + if compute_capability_number != 100: + pytest.skip("DSv3 Router GEMM is only supported on SM100") + + mat_a = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=mat_a_dtype) + mat_b = torch.randn(num_experts, hidden_dim, device="cuda", dtype=mat_b_dtype) + if mat_b_transpose: + mat_b = mat_b.t() # column major + out = torch.randn(num_tokens, num_experts, device="cuda", dtype=out_dtype) + + with pytest.raises(ValueError, match=expected_error): + mm_M1_16_K7168_N256(mat_a, mat_b, out, launch_with_pdl=False) From e450c7dc9e10a680d9fe5b98ed023d83ea676db9 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 7 Nov 2025 14:54:43 -0800 Subject: [PATCH 040/130] Fix moe fp8 failure for sm121 (#2061) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description fix the failure for sm121 in [pipeline](https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/230180150) ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). 
## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Extended FP8 grouped matrix-multiplication support to include an additional GPU architecture (SM121), providing the same optimized tile configuration options as the previously supported SM variants, improving performance consistency and broader hardware compatibility for FP8 workloads. Co-authored-by: Zihao Ye --- .../tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 34a90c65f9..a20171d8b6 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -158,7 +158,7 @@ std::vector get_candidate_tiles( CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; case CutlassGemmType::Fp8: if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { - if (sm == 89 || sm >= 120) { + if (sm == 89 || sm == 120 || sm == 121) { return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, From ba011d15ed44d32c9c226fd1d66746707233729c Mon Sep 17 00:00:00 2001 From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com> Date: Sat, 8 Nov 2025 07:14:05 +0100 Subject: [PATCH 041/130] perf: TRT-LLM MoE Block-FP8 activation optimization (#2063) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description - Small optimization to the activation kernel for block-FP8 MoE for large batch size. | BS | Baseline, us | Optimized, us | | ------------- | ------------- | ------------- | | 1 | 2.4 | 2.1 | | 32 | 3.5 | 2.6 | | 256 | 21.7 | 8.7 | | 1024 | 84.4 | 23.8 | | 4096 | 333 | 87.0 | | 16384 | 1330 | 365 | - Adding micro-benchmark for DS FP8 implemented by @IwakuraRein. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Improved Mixture-of-Experts inference with configurable multi-token batching per GPU core for higher throughput. * Expanded FP8 quantization with a new block-scale mode and dynamic, hardware-aware kernel scheduling for better utilization and numerical stability. * Vectorized max-reduction and per-block scaling to accelerate reductions and improve output scaling precision. * Autotuner/CLI now exposes the FP8 block quantization option for tuning. 
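To make the scheduling change above concrete: the activation launcher now derives how many tokens each CTA batches from the grid size relative to the SM count, which is what drives the large-batch gains in the table above (e.g. 84.4 us -> 23.8 us at BS=1024, roughly 3.5x). A simplified Python rendering of the heuristic added in the CUDA changes below (function and argument names here are illustrative, not part of the API):

```python
def tokens_per_cta(num_scale_block_ctas: int, num_tokens: int, top_k: int, num_sms: int) -> int:
    # Batch more tokens per CTA only when the grid is large enough to keep every SM busy.
    num_ctas = num_scale_block_ctas * num_tokens * top_k
    if num_ctas > num_sms * 32:
        return 4
    if num_ctas > num_sms * 4:
        return 2
    return 1
```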
--------- Signed-off-by: Siyuan Fu Co-authored-by: Siyuan Fu --- .../bench_trtllm_gen_fused_moe_autotuner.py | 138 ++++++--- csrc/trtllm_fused_moe_dev_kernel.cu | 273 +++++++++++++++--- .../flashinfer/trtllm/fused_moe/DevKernel.h | 29 +- 3 files changed, 356 insertions(+), 84 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index e7e40e772f..0aff25860e 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -11,6 +11,8 @@ from flashinfer.fused_moe import ( trtllm_fp4_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, + trtllm_fp8_block_scale_moe, + WeightLayout, ) from flashinfer.autotuner import autotune from flashinfer.testing.utils import bench_gpu_time @@ -21,7 +23,7 @@ def fp8_quantize(x): - max = x.float().abs().nan_to_num().max() + max = x.abs().max().float() scale = FLOAT8_E4M3_MAX / max x = (x * scale).to(torch.float8_e4m3fn) return x, 1.0 / scale @@ -29,7 +31,7 @@ def fp8_quantize(x): def bench_trtllm_gen_fused_moe_autotuner_fp8( tune_max_num_tokens: Optional[int], - quant_mode: Literal["Fp8-Per-Tensor"], + quant_mode: Literal["Fp8-Per-Tensor", "Fp8-Block"], num_tokens: int, num_experts: int, hidden_size: int, @@ -41,11 +43,12 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) routing_logits = torch.rand(num_tokens, num_experts, device=device).to( - torch.bfloat16 + torch.float32 ) hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( torch.bfloat16 ) + routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16) w13 = torch.randn( num_experts, intermediate_size * 2, hidden_size, device=device ).to(torch.bfloat16) @@ -53,43 +56,97 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( torch.bfloat16 ) - hidden_states, hidden_states_scale = fp8_quantize(hidden_states) - w13, w13_scale = fp8_quantize(w13) - w2, w2_scale = fp8_quantize(w2) + is_block_scale = quant_mode == "Fp8-Block" + if not is_block_scale: + hidden_states, hidden_states_scale = fp8_quantize(hidden_states) + w13, w13_scale = fp8_quantize(w13) + w2, w2_scale = fp8_quantize(w2) + else: + # block scale quantization is too slow, so we use per-tensor quantization for now + hidden_states, hidden_states_scale = fp8_quantize(hidden_states) + w13, w13_scale = fp8_quantize(w13) + w2, w2_scale = fp8_quantize(w2) + hidden_states_scale = torch.full( + (hidden_size // 128, num_tokens), hidden_states_scale.item(), device=device + ) + w13_scale = torch.full( + (num_experts, intermediate_size * 2 // 128, hidden_size // 128), + w13_scale.item(), + device=device, + ) + w2_scale = torch.full( + (num_experts, hidden_size // 128, intermediate_size // 128), + w2_scale.item(), + device=device, + ) - output1_scale_scalar = torch.tensor( - [hidden_states_scale * w13_scale] * num_experts, device=device + output1_scale_scalar = ( + torch.tensor([hidden_states_scale * w13_scale] * num_experts, device=device) + if not is_block_scale + else None ) - output1_scales_gate_scalar = torch.ones( - num_experts, device=device, dtype=torch.float32 + output1_scales_gate_scalar = ( + torch.ones(num_experts, device=device, dtype=torch.float32) + if not is_block_scale + else None ) - output2_scale_scalar = torch.tensor( - [hidden_states_scale * w2_scale] * num_experts, device=device + output2_scale_scalar = ( + torch.tensor([hidden_states_scale * w2_scale] * num_experts, device=device) + if not is_block_scale + else 
None ) - fn = lambda: trtllm_fp8_per_tensor_scale_moe( - routing_logits, - None, # routing_bias - hidden_states, - w13, - output1_scale_scalar, - output1_scales_gate_scalar, - w2, - output2_scale_scalar, - num_experts, - top_k, - None, # n_group - None, # topk_group - intermediate_size, - 0, # local_expert_offset - num_experts, - 1.0, # routed_scaling_factor - False, # use_routing_scales_on_input - None, - RoutingMethodType.TopK.value, - enable_pdl, - num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, - ) + if is_block_scale: + fn = lambda: trtllm_fp8_block_scale_moe( + routing_logits, + routing_bias, + hidden_states, + hidden_states_scale, + w13, + w13_scale, + w2, + w2_scale, + num_experts, + top_k, + 8, # n_group + 4, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + 2.5, # routed_scaling_factor + None, # tile_tokens_dim + RoutingMethodType.DeepSeekV3.value, + True, # use_shuffled_weight + WeightLayout.BlockMajorK.value, # weight_layout + enable_pdl=enable_pdl, + tune_max_num_tokens=num_tokens + if tune_max_num_tokens is None + else tune_max_num_tokens, + ) + else: + fn = lambda: trtllm_fp8_per_tensor_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + w13, + output1_scale_scalar, + output1_scales_gate_scalar, + w2, + output2_scale_scalar, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + 1.0, # routed_scaling_factor + False, # use_routing_scales_on_input + None, # tile_tokens_dim + RoutingMethodType.TopK.value, + enable_pdl, + num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + ) def bench(do_autotune): with autotune(do_autotune): @@ -135,6 +192,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( torch.tensor([448.0 * 6.0], device=device), sf_vec_size=16, sf_use_ue8m0=False, + is_sf_swizzled_layout=False, ) hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( num_tokens, -1 @@ -263,7 +321,13 @@ def bench(do_autotune): "--quant-mode", type=str, default="MxFP4xMxFP8", - choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16", "Fp8-Per-Tensor"], + choices=[ + "NvFP4xNvFP4", + "MxFP4xMxFP8", + "MxFP4xBf16", + "Fp8-Per-Tensor", + "Fp8-Block", + ], help="Quantization mode", ) parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens") @@ -288,7 +352,7 @@ def bench(do_autotune): "--iterations", type=int, default=100, help="Number of benchmark iterations" ) args = parser.parse_args() - if args.quant_mode == "Fp8-Per-Tensor": + if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]: bench_trtllm_gen_fused_moe_autotuner_fp8( args.tune_max_num_tokens, args.quant_mode, diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu index f596d046b8..9a51384090 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/trtllm_fused_moe_dev_kernel.cu @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -26,6 +27,7 @@ #include "flashinfer/exception.h" #include "flashinfer/trtllm/fused_moe/DevKernel.h" +#include "flashinfer/utils.cuh" //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -93,13 +95,118 @@ __global__ void activationKernel(KernelParams params) { //////////////////////////////////////////////////////////////////////////////////////////////////// +struct Float4Max { + __device__ __forceinline__ float4 operator()(float4 const& a, float4 const& b) const { + float4 result; + result.x = fmaxf(a.x, 
b.x); + result.y = fmaxf(a.y, b.y); + result.z = fmaxf(a.z, b.z); + result.w = fmaxf(a.w, b.w); + return result; + } +}; + +struct Float2Max { + __device__ __forceinline__ float2 operator()(float2 const& a, float2 const& b) const { + float2 result; + result.x = fmaxf(a.x, b.x); + result.y = fmaxf(a.y, b.y); + return result; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ VecType packedTypeFromArray(float data[size]) { + return {}; +} + +template <> +__device__ __forceinline__ float4 packedTypeFromArray(float data[4]) { + float4 result; + result.x = data[0]; + result.y = data[1]; + result.z = data[2]; + result.w = data[3]; + return result; +} + +template <> +__device__ __forceinline__ float2 packedTypeFromArray(float data[2]) { + float2 result; + result.x = data[0]; + result.y = data[1]; + return result; +} + +template <> +__device__ __forceinline__ float packedTypeFromArray(float data[1]) { + return data[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ cutlass::Array arrayFromPackedType(PackedType data) { + return cutlass::Array{}; +} + +template <> +__device__ __forceinline__ cutlass::Array arrayFromPackedType(float4 data) { + return cutlass::Array{data.x, data.y, data.z, data.w}; +} + +template <> +__device__ __forceinline__ cutlass::Array arrayFromPackedType(float2 data) { + return cutlass::Array{data.x, data.y}; +} + +template <> +__device__ __forceinline__ cutlass::Array arrayFromPackedType(float data) { + return cutlass::Array{data}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct KernelTraits; + +template <> +struct KernelTraits<4> { + using MaxOp = Float4Max; + using PackedType = float4; +}; + +template <> +struct KernelTraits<2> { + using MaxOp = Float2Max; + using PackedType = float2; +}; + +template <> +struct KernelTraits<1> { +#if CUDA_VERSION >= 12090 + using MaxOp = cuda::maximum<>; +#else + using MaxOp = cub::Max; +#endif + using PackedType = float; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template __global__ void activationDeepSeekKernel(KernelParams params) { using Type = typename KernelParams::Type; - using BlockReduce = cub::BlockReduce; + int32_t constexpr NumTokensPerCta = KernelParams::NumTokensPerCta; + using KernelTraits = KernelTraits; + using MaxOp = typename KernelTraits::MaxOp; + using PackedType = typename KernelTraits::PackedType; + using BlockReduce = cub::BlockReduce; - __shared__ float s_scaleOut; - __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ float s_scaleOutArr[NumTokensPerCta]; + __shared__ typename BlockReduce::TempStorage tempStorage; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) // immediately trigger the secondary kernel when using PDL, then wait on primary @@ -108,54 +215,101 @@ __global__ void activationDeepSeekKernel(KernelParams params) { cudaGridDependencySynchronize(); } #endif - // Loop over tokens - for (int tokenIdx = blockIdx.z; tokenIdx < params.numTokens; tokenIdx += gridDim.z) { - // Look over experts per token - for (int k = blockIdx.y; k < params.topK; k += gridDim.y) { - int const expandedIdx = tokenIdx * params.topK + k; - int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx]; - // Needed for expert parallelism - if (permutedIdx == -1) 
continue; + // The largest (finite) value that can be represented using E4m3. + float constexpr E4m3MaxVal{448.f}; - // Loop over hidden dim + int const totalNumPaddedTokens = params.totalNumPaddedTokens[0]; + // Loop over tokens + float scale1Arr[NumTokensPerCta]; + float scale2Arr[NumTokensPerCta]; + float dataX1Arr[NumTokensPerCta]; + float dataX2Arr[NumTokensPerCta]; + float outArr[NumTokensPerCta]; + float absOutArr[NumTokensPerCta]; + int permutedIdxArr[NumTokensPerCta]; + + // Loop over tokens + for (int k = blockIdx.z; k < params.topK; k += gridDim.z) { + for (int tokenCtaIdx = blockIdx.y * NumTokensPerCta; tokenCtaIdx < params.numTokens; + tokenCtaIdx += gridDim.y * NumTokensPerCta) { for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2; hiddenIdx += blockDim.x * gridDim.x) { - int const baseIdx = permutedIdx * params.innerDim + hiddenIdx; - - int const totalNumPaddedTokens = params.totalNumPaddedTokens[0]; +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + int const tokenIdx = tokenCtaIdx + tokenInCtaIdx; + if (tokenIdx >= params.numTokens) { + break; + } - int const scale1_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128); - int const scale2_idx = - permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128)); - float const scale1 = params.inDqSfsPtr[scale1_idx]; - float const scale2 = params.inDqSfsPtr[scale2_idx]; - - float x1 = scale1 * (float)params.inPtr[baseIdx]; - float x2 = scale2 * (float)params.inPtr[baseIdx + params.innerDim / 2]; - - float act = silu(x2); - float out = act * x1; + int const expandedIdx = tokenIdx * params.topK + k; + int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx]; + permutedIdxArr[tokenInCtaIdx] = permutedIdx; + if (permutedIdx == -1) { + continue; + } + + // Process blocks for this CTA + int const baseIdx = permutedIdx * params.innerDim + hiddenIdx; + + int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128); + int const scale2Idx = permutedIdx + totalNumPaddedTokens * + ((hiddenIdx / 128) + (params.innerDim / 2 / 128)); + + scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx]; + scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx]; + dataX1Arr[tokenInCtaIdx] = static_cast(params.inPtr[baseIdx]); + dataX2Arr[tokenInCtaIdx] = + static_cast(params.inPtr[baseIdx + params.innerDim / 2]); + } - // The largest (finite) value that can be represented using E4m3. 
- float constexpr E4m3MaxVal{448.f}; +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx]; + float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx]; + float act = silu(x2); + float out = act * x1; + outArr[tokenInCtaIdx] = out; + absOutArr[tokenInCtaIdx] = fabsf(out); + } - // Compute the absolute max -#if CUDA_VERSION >= 12090 - float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cuda::maximum<>{}); -#else - float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cub::Max{}); -#endif - if (threadIdx.x == 0) { - s_scaleOut = aMax / E4m3MaxVal; - int const scaleOut_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128); - params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal; + auto absOutPacked = packedTypeFromArray(absOutArr); + auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{}); + auto aMaxArr = arrayFromPackedType(aMaxPacked); + +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + if (threadIdx.x == 0) { + auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx; + if (tokenIdx >= params.numTokens) { + break; + } + int const permutedIdx = permutedIdxArr[tokenInCtaIdx]; + if (permutedIdx == -1) { + continue; + } + s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal; + int const scaleOut_idx = + permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128); + params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal; + } } __syncthreads(); - float const scaleOut = s_scaleOut; - __syncthreads(); - int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx; - params.outPtr[outIdx] = (Type)(out / scaleOut); + +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx; + if (tokenIdx >= params.numTokens) { + break; + } + int const permutedIdx = permutedIdxArr[tokenInCtaIdx]; + if (permutedIdx == -1) { + continue; + } + float const scaleOut = s_scaleOutArr[tokenInCtaIdx]; + int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx; + params.outPtr[outIdx] = static_cast(outArr[tokenInCtaIdx] / scaleOut); + } } } } @@ -172,15 +326,42 @@ void run(Data const& data, void* stream) { } if (data.mUseDeepSeekFp8) { - int const numThreads = 128; - const dim3 grid(data.innerDim / 128, data.topK, data.numTokens); + constexpr int NUM_ELTS_PER_LOAD = 1; + constexpr int NUM_ELTS_PER_SF = 128; + int const NUM_THREADS_PER_CTA = 128; + + int device{-1}; + cudaGetDevice(&device); + int numSms = 0; + cudaDeviceGetAttribute(&numSms, cudaDevAttrMultiProcessorCount, device); + + // Output dimension is innerDim / 2, and each scale block is 128 elements + int const outputDim = data.innerDim / 2; + int const numScaleBlocks = (outputDim + NUM_ELTS_PER_SF - 1) / NUM_ELTS_PER_SF; + int const gridSizeX = (numScaleBlocks + NUM_ELTS_PER_LOAD - 1) / NUM_ELTS_PER_LOAD; + + auto numCtas = gridSizeX * data.numTokens * data.topK; + // FIXME: This is heruistic based on very short benchmark. 
+ int numTokensPerCta = 1; + if (numCtas > numSms * 32) { + numTokensPerCta = 4; + } else if (numCtas > numSms * 4) { + numTokensPerCta = 2; + } else { + numTokensPerCta = 1; + } + + int const gridSizeY = std::min(8192, (data.numTokens + numTokensPerCta - 1) / numTokensPerCta); + + const dim3 grid(gridSizeX, gridSizeY, data.topK); - LAUNCH(data, activationDeepSeekKernel, grid, numThreads, 0, stream); + LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid, NUM_THREADS_PER_CTA, 0, + stream); } else { int const numThreads = 256; const dim3 grid(data.innerDim / 128, data.topK, data.numTokens); - LAUNCH(data, activationKernel, grid, numThreads, 0, stream); + LAUNCH_ACTIVATION(data, activationKernel, 1, grid, numThreads, 0, stream); } } diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index e3a0d21884..0ee9ba6fe9 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -90,6 +90,32 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported dtypeElt"); \ } +#define LAUNCH_NUM_TOKENS_PER_CTA(data, type, numTokensPerCta, kernel, numBlocks, numThreads, \ + smemSize, stream) \ + if (numTokensPerCta == 4) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(type, 4), kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (numTokensPerCta == 2) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(type, 2), kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (numTokensPerCta == 1) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(type, 1), kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported numTokensPerCta"); \ + } + +#define LAUNCH_ACTIVATION(data, kernel, numTokensPerCta, numBlocks, numThreads, smemSize, stream) \ + if (data.mDtypeElt == tg::Dtype::Fp16) { \ + LAUNCH_NUM_TOKENS_PER_CTA(data, cutlass::half_t, numTokensPerCta, kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::E4m3) { \ + LAUNCH_NUM_TOKENS_PER_CTA(data, cutlass::float_e4m3_t, numTokensPerCta, kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::Bfloat16) { \ + LAUNCH_NUM_TOKENS_PER_CTA(data, cutlass::bfloat16_t, numTokensPerCta, kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeElt"); \ + } + #define LAUNCH_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \ if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, \ @@ -234,9 +260,10 @@ struct Data { int32_t const* totalNumPaddedTokens; }; -template +template struct KernelParams { using Type = Type_; + static constexpr int32_t NumTokensPerCta = NumTokensPerCta_; static constexpr bool UsePdl = UsePdl_; Type const* inPtr; From 74281ed4b326a72a3ce758fad1338e81ad6abe6f Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 7 Nov 2025 22:24:15 -0800 Subject: [PATCH 042/130] [feat] Refactor trtllmgen MOE and add Bf16 trtllmgen moe (#2014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description - Refactor `trtllm_fused_moe_kernel_launcher.cu` to use class structure for code cleanliness and readability - Add BF16 MOE, initial PR (https://github.com/flashinfer-ai/flashinfer/pull/1859) from @aleozlx and @nekorobov - Add BF16 MOE autotune ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull 
Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * BF16 Mixture-of-Experts (MoE) pathway added with autotuning and public API access. * **Improvements** * Unified BF16/FP8/FP4/FP16 pathways with clearer dtype compatibility checks and corrected operator return semantics. * Routing selection now respects token-size and input packing, and diagnostics produce more descriptive error messages. * **Tests** * Expanded BF16 test coverage across routing modes, weight layouts, and token sizes. * **Chores** * Updated artifact metadata and checksums. --------- Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- csrc/trtllm_batched_gemm_runner.cu | 18 +- csrc/trtllm_fused_moe_kernel_launcher.cu | 2524 ++++++++++-------- csrc/trtllm_fused_moe_routing_renormalize.cu | 4 +- flashinfer/artifacts.py | 6 +- flashinfer/fused_moe/__init__.py | 4 + flashinfer/fused_moe/core.py | 285 +- tests/moe/test_trtllm_gen_fused_moe.py | 234 +- 7 files changed, 1932 insertions(+), 1143 deletions(-) diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index 42fe8f7f59..cff4db198f 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -116,14 +116,16 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( } } - FLASHINFER_CHECK( - !mPassingConfigIndices.empty(), - "No kernel found for the given options: mDtypeA: %s, mDtypeB: %s, mDtypeC: %s, " - "mUseDeepSeekFp8: %d, " - "mTransposeMmaOutput: %d, mRouteAct: %d, mFusedAct: %d, mIsStaticBatch: %d, mTileSize: %d", - tg::dtypeToString(mOptions.dtypeA).c_str(), tg::dtypeToString(mOptions.dtypeB).c_str(), - tg::dtypeToString(mOptions.dtypeC).c_str(), mOptions.deepSeekFp8, mOptions.transposeMmaOutput, - mOptions.routeAct, mOptions.fusedAct, mOptions.staticBatch, mOptions.tileSize); + std::ostringstream error_msg; + error_msg << "No kernel found for the given options: " + << "mDtypeA: " << tg::dtypeToString(mOptions.dtypeA) + << ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB) + << ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC) + << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8 + << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput + << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct + << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize; + FLASHINFER_CHECK(!mPassingConfigIndices.empty(), error_msg.str()); } size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes( diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 3fd9dab35e..0688c1e97d 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include #include @@ -83,1095 +85,1413 @@ std::set computeSelectedTileN(std::vector const& 
supported_til return selected_tile_nums; } -void trtllm_fp8_per_tensor_scale_moe_launcher( - TensorView routing_logits, Optional routing_bias, TensorView hidden_states, - TensorView gemm1_weights, TensorView output1_scales_scalar, - TensorView output1_scales_gate_scalar, TensorView gemm2_weights, - TensorView output2_scales_scalar, TensorView output, int64_t const num_experts, - int64_t const top_k, Optional const n_group, Optional const topk_group, - int64_t const intermediate_size, int64_t const local_expert_offset, - int64_t const local_num_experts, Optional const routed_scaling_factor, - bool const use_routing_scales_on_input, int64_t const tile_tokens_dim, - int64_t const routing_method_type, - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex, - bool enable_pdl) { - static const std::tuple device_props = [hidden_states] { - int major, minor; - cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, - hidden_states.device().device_id); - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, - hidden_states.device().device_id); - return std::make_tuple(major, minor); - }(); - - TVM_FFI_ICHECK_EQ(std::get<0>(device_props), 10) - << "This kernel requires 10.x architecture. Current device has SM " - << std::get<0>(device_props) << std::get<1>(device_props); +class FusedMoeLauncher { + protected: + Optional routing_logits; + Optional routing_bias; + TensorView hidden_states; + TensorView gemm1_weights; + Optional output1_scales_scalar; + Optional output1_scales_gate_scalar; + TensorView gemm2_weights; + Optional output2_scales_scalar; + + int64_t tile_tokens_dim{}; + int64_t routing_method_type{}; + bool use_shuffled_weight{}; + batchedGemm::gemm::MatrixLayout weight_layout{batchedGemm::gemm::MatrixLayout::MajorK}; + + std::tuple device_version; + std::unique_ptr args; + tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace; - if (use_routing_scales_on_input) { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; - } else if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; - } else { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; + btg::Dtype mDtypeAct{btg::Dtype::Bfloat16}; + btg::Dtype mDtypeWeights{btg::Dtype::Bfloat16}; + btg::Dtype mRoutingBiasDtype{ + btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias + GatedActType gated_act_type{GatedActType::SwiGlu}; + + public: + // Constructor that initializes all TensorView members + FusedMoeLauncher(const Optional& routing_logits, + const Optional& routing_bias, const TensorView& hidden_states, + const TensorView& gemm1_weights, + const Optional& output1_scales_scalar, + const Optional& output1_scales_gate_scalar, + const TensorView& gemm2_weights, + const Optional& output2_scales_scalar) + : routing_logits(routing_logits), + routing_bias(routing_bias), + hidden_states(hidden_states), + gemm1_weights(gemm1_weights), + output1_scales_scalar(output1_scales_scalar), + output1_scales_gate_scalar(output1_scales_gate_scalar), + gemm2_weights(gemm2_weights), + output2_scales_scalar(output2_scales_scalar), + tile_tokens_dim{}, + routing_method_type{}, + use_shuffled_weight{}, + weight_layout{batchedGemm::gemm::MatrixLayout::MajorK}, + mDtypeAct{btg::Dtype::Bfloat16}, + mDtypeWeights{btg::Dtype::Bfloat16}, + gated_act_type{GatedActType::SwiGlu} {} + 
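+ // --------------------------------------------------------------------------
+ // Illustrative note (editor's sketch, not from the original patch): a derived
+ // launcher is driven through run() further below, which executes:
+ //   1. check_routing()/prepare_routing()  -> validate inputs, allocate routing workspace
+ //   2. Routing::Runner::run(...)          -> expert selection and permutation indices
+ //   3. check_moe()/prepare_moe()          -> allocate GEMM workspaces, resolve moe_tactic
+ //   4. MoE::Runner::run(...)              -> fused GEMM1 + activation + GEMM2 (+ optional finalize)
+ // Typical use with the BF16 launcher defined below (argument values are placeholders):
+ //   Bf16MoeLauncher launcher(routing_logits, routing_bias, hidden_states,
+ //                            gemm1_weights, gemm2_weights);
+ //   launcher.init(std::move(args), tile_tokens_dim, routing_method_type,
+ //                 use_shuffled_weight, weight_layout);
+ //   auto outputs = launcher.run(/*moe_tactic=*/-1);  // -1 picks the default valid config
+ // --------------------------------------------------------------------------
+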
+ protected: + // Initialize common data necessary for later. + // May throw exception from TVM_FFI_ICHECK. + void init_common(std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, int64_t gated_act_type); + + // Routing logits [num_tokens, num_experts] + void check_routing_logits_shape() const { + if (routing_logits.has_value()) { + TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits.value().size(0), hidden_states.size(0)) + << "routing_logits and hidden_states must have the same number of tokens."; + TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), args->num_experts) + << "routing_logits dim1 must match num_experts."; + } } - TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D."; - TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) << "routing_logits has incorrect shape."; - if (routing_bias.has_value()) { - TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || - routing_bias.value().dtype() == dl_float32) - << "routing_bias must be bfloat16 or float."; - TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; - TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) - << "routing_bias has incorrect shape."; + + // Routing bias [num_experts] + void check_routing_bias_shape() const { + if (routing_bias.has_value()) { + TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; + TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), args->num_experts) + << "routing_bias has incorrect shape."; + } } - if (n_group.has_value() && n_group.value() != 0) { - TVM_FFI_ICHECK(static_cast(routing_method_type) == - RoutingMethodType::DeepSeekV3) - << "Routing kernel with groups implies DeepSeekV3 routing method."; - TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given"; - TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0) - << "num_experts must be divisible by n_group"; - TVM_FFI_ICHECK(top_k <= 8 && top_k > 0) - << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."; - TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0) - << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."; - TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value()) - << "n_group must not be smaller than topk_group."; - // This check ensures we have enough experts in the selected groups to handle the top_k routing - TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value())) - << "top_k must be less than total number of experts in selected groups"; - } else if (static_cast(routing_method_type) == - RoutingMethodType::Renormalize || - static_cast(routing_method_type) == - RoutingMethodType::RenormalizeNaive) { - TVM_FFI_LOG_AND_THROW(NotImplementedError) - << "Don't support routing method type Renormalize(Naive)."; - } else if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { - TVM_FFI_ICHECK_EQ(top_k, 1) - << "Current routing kernel (no groups, Llama4) only supports top_k=1."; + // Hidden states [num_tokens, hidden_size] + void check_hidden_states_shape() const { + TVM_FFI_ICHECK_EQ(hidden_states.ndim(), 2) << "hidden_states must be 2D."; + TVM_FFI_ICHECK_EQ(hidden_states.size(1), args->intermediate_size) + << "hidden_states has incorrect shape."; } - TVM_FFI_ICHECK_EQ(num_experts % 4, 0) - << "Routing kernel expects that num_experts must be divisible by 4"; - 
TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k"; - TVM_FFI_ICHECK_LE(local_num_experts + local_expert_offset, num_experts) - << "num_experts must be greater or equal to local_num_experts + local_expert_offset"; - tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs args; - tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace; + // GEMM1 or GEMM2 weights [num_experts, M, K] or [num_experts, K/block_k, M, block_k] + void check_weights_shape(std::string which_weights) const { + TensorView weights = (which_weights == "gemm1") ? gemm1_weights : gemm2_weights; + if (which_weights != "gemm1" && which_weights != "gemm2") { + TVM_FFI_LOG_AND_THROW(InternalError) << "Internal error: which_weights = " << which_weights; + } - // Convert PyTorch dtype to TensorRT-LLM dtype - auto dtype = hidden_states.dtype(); - if (dtype == dl_float16) { - args.mDtypeElt = btg::Dtype::Fp16; - } else if (dtype == dl_bfloat16) { - args.mDtypeElt = btg::Dtype::Bfloat16; - } else if (dtype == dl_float8_e4m3fn) { - args.mDtypeElt = btg::Dtype::E4m3; - } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; + int64_t Mn = 0, K = 0; + if (weight_layout == batchedGemm::gemm::MatrixLayout::MajorK) { + // MajorK [num_experts, M, K] + Mn = weights.size(1); + K = weights.size(2); + } else if (weight_layout == batchedGemm::gemm::MatrixLayout::BlockMajorK) { + // BlockMajorK [num_experts, K/block_k, M, block_k] + Mn = weights.size(2); + int64_t block_k = weights.size(3); + K = weights.size(1) * block_k; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "Unsupported weight_layout: " << (int)weight_layout; + } + if (which_weights == "gemm1") { + TVM_FFI_ICHECK_EQ(Mn % 2, 0) << which_weights << " weights Mn dimension must be even."; + TVM_FFI_ICHECK_EQ(args->intermediate_size, Mn / 2) + << "intermediate_size has incorrect shape."; + TVM_FFI_ICHECK_EQ(K, hidden_states.size(1)) + << which_weights << " weights K dimension must be equal to hidden_size."; + } else if (which_weights == "gemm2") { + TVM_FFI_ICHECK_EQ(K, args->intermediate_size) + << which_weights << " weights K dimension must be equal to intermediate_size."; + } } - args.mDtypeOut = btg::Dtype::Bfloat16; // Output is always bfloat16 for fp8 per-tensor scale - - args.routing_logits = routing_logits.data_ptr(); - auto const routing_bias_dtype = - routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; - auto btg_routing_bias_dtype = btg::Dtype::Fp32; - if (routing_bias_dtype == dl_bfloat16) { - btg_routing_bias_dtype = btg::Dtype::Bfloat16; + + void check_routing_common() const { + TVM_FFI_ICHECK(args->top_k > 0 && args->top_k <= args->num_experts) + << "top_k must be between 1 and num_experts"; + TVM_FFI_ICHECK(args->local_num_experts > 0 && args->local_num_experts <= args->num_experts) + << "local_num_experts must be between 1 and num_experts"; + TVM_FFI_ICHECK(args->local_expert_offset >= 0 && + args->local_expert_offset + args->local_num_experts <= args->num_experts) + << "expert offset and count must be within valid range"; + + check_routing_logits_shape(); + + if (routing_bias.has_value()) { + check_routing_bias_shape(); + } } - args.routing_bias = routing_bias.has_value() ? 
routing_bias.value().data_ptr() : nullptr; - args.hidden_states = hidden_states.data_ptr(); - args.gemm1_weights = gemm1_weights.data_ptr(); - args.output1_scales_scalar = static_cast(output1_scales_scalar.data_ptr()); - args.output1_scales_gate_scalar = static_cast(output1_scales_gate_scalar.data_ptr()); - args.gemm2_weights = gemm2_weights.data_ptr(); - args.output2_scales_scalar = static_cast(output2_scales_scalar.data_ptr()); - args.num_tokens = hidden_states.size(0); - args.num_experts = num_experts; - args.hidden_size = hidden_states.size(1); - args.hidden_size_output = args.hidden_size; - args.top_k = top_k; - args.n_group = n_group.has_value() ? n_group.value() : 0; - args.topk_group = topk_group.has_value() ? topk_group.value() : 0; - args.local_expert_offset = local_expert_offset; - args.local_num_experts = local_num_experts; - args.routed_scaling_factor = - routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0; - args.intermediate_size = intermediate_size; - args.mUseRoutingScalesOnInput = use_routing_scales_on_input; - - // allocate workspace for routing kernel - Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, routing_logits.device()); - int32_t max_num_padded_tokens = - tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( - args.num_tokens, top_k, num_experts, tile_tokens_dim); - int32_t max_num_padded_tokens_gemm1 = - tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( - max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt)); - int32_t max_num_padded_tokens_gemm2 = - tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( - max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); - - Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits.device()); - Tensor expanded_idx_to_permuted_idx = - alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits.device()); - Tensor permuted_idx_to_token_idx = - alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits.device()); - Tensor expert_weights = - alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits.device()); - Tensor expert_indexes = - alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits.device()); - int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); - Tensor expert_count_histogram = alloc_tensor( - {size_of_expert_count_histogram}, - dl_int32, // 256 is the max number of threads per block and max number of experts - routing_logits.device()); - - // allocate workspace for activation/gemm/finalize kernels - Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8, - hidden_states.device()); - Tensor gemm1_output_scale = - alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, - hidden_states.device()); - Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size}, - dl_uint8, hidden_states.device()); - Tensor activation_output_scale = alloc_tensor( - {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states.device()); - Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, - hidden_states.device()); - int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( - args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); - Tensor cta_idx_xy_to_batch_idx = 
alloc_tensor({max_num_ctas}, dl_int32, routing_logits.device()); - Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, routing_logits.device()); - Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, routing_logits.device()); - - tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); - cudaStream_t stream = get_stream(routing_logits.device()); - routing_runner.run( - routing_logits.data_ptr(), args.routing_bias, args.num_tokens, args.num_experts, args.top_k, - args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, - args.routed_scaling_factor, static_cast(expert_indexes.data_ptr()), - static_cast(expert_count_histogram.data_ptr()), - static_cast(total_num_padded_tokens.data_ptr()), - static_cast(expanded_idx_to_permuted_idx.data_ptr()), - nullptr /*static_cast(permuted_idx_to_expanded_idx.data_ptr())*/, - static_cast(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), - static_cast(num_tokens_per_expert.data_ptr()), - static_cast(cta_idx_xy_to_batch_idx.data_ptr()), - static_cast(cta_idx_xy_to_mn_limit.data_ptr()), - static_cast(num_non_exiting_ctas.data_ptr()), args.mDtypeElt, btg_routing_bias_dtype, - use_routing_scales_on_input, false /* use_deep_seek_fp8 */, - static_cast(routing_method_type), stream); - - // MoE kernel except routing - TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8."; - TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8."; - TVM_FFI_ICHECK_EQ(gemm1_weights.ndim(), 3) << "gemm1_weights must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights.size(1) % 2, 0) - << "the second dimension of weights must be even."; - TVM_FFI_ICHECK_EQ(intermediate_size, gemm1_weights.size(1) / 2) - << "intermediate_size has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm1_weights.size(2), hidden_states.size(1)) - << "the third dimension of weights must be equal to hidden_size."; - TVM_FFI_ICHECK_EQ(intermediate_size % 128, 0) - << "the second dimension of weights must be a multiple of 128."; + // Routing phase workspace tensors (allocated in prepare_routing() or prepare_routing_common()) + Tensor num_tokens_per_expert; + Tensor total_num_padded_tokens; + Tensor expanded_idx_to_permuted_idx; + Tensor permuted_idx_to_token_idx; + Tensor expert_weights; + Tensor expert_indexes; + Tensor expert_count_histogram; + Tensor cta_idx_xy_to_batch_idx; + Tensor cta_idx_xy_to_mn_limit; + Tensor num_non_exiting_ctas; + + void prepare_routing_common() { + // Allocate routing phase workspace tensors + num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device()); + int32_t max_num_padded_tokens = + tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( + args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + + total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device()); + + expanded_idx_to_permuted_idx = + alloc_tensor({args->num_tokens * args->top_k}, dl_int32, hidden_states.device()); + + permuted_idx_to_token_idx = + alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device()); + + expert_indexes = + alloc_tensor({args->num_tokens, args->top_k}, dl_int32, hidden_states.device()); + + // expert_weights allocation should be done by derived class since data type could vary + + int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2); + expert_count_histogram = alloc_tensor({size_of_expert_count_histogram}, + dl_int32, // 256 
is the max number of threads per block + // and max number of experts + hidden_states.device()); + + int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( + args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + + cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); + + cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); + + num_non_exiting_ctas = alloc_tensor({1}, dl_int32, hidden_states.device()); + + workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); + workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.ProjUpTileN = tile_tokens_dim; + workspace.routing_expert_indexes = static_cast(expert_indexes.data_ptr()); + workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); + workspace.expanded_idx_to_permuted_idx = + static_cast(expanded_idx_to_permuted_idx.data_ptr()); + workspace.permuted_idx_to_token_idx = static_cast(permuted_idx_to_token_idx.data_ptr()); + // workspace.expert_weights will be set by derived class after expert_weights allocation + workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); + workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); + workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); + } - TVM_FFI_ICHECK_EQ(output1_scales_scalar.dtype(), dl_float32) - << "output1_scales_scalar must be float."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar.ndim(), 1) << "output1_scales_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar.size(0), local_num_experts) - << "output1_scales_scalar has incorrect dim 0."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.dtype(), dl_float32) - << "output1_scales_gate_scalar must be float."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.ndim(), 1) - << "output1_scales_gate_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.size(0), local_num_experts) - << "output1_scales_gate_scalar has incorrect dim 0."; + void check_moe_common() const { + // Hidden states [num_tokens, hidden_size] + TVM_FFI_ICHECK_EQ(hidden_states.ndim(), 2) << "hidden_states must be 2D."; + } + + // MoE computation phase workspace tensors (allocated in prepare_moe() or prepare_moe_common()) + Tensor gemm1_output; + Tensor activation_output; + Tensor gemm2_output; + Tensor workspace_fc1; + Tensor workspace_fc2; + Tensor output; + int64_t moe_tactic{-1}; + std::unique_ptr moe_runner; + + void prepare_moe_common(int64_t& moe_tactic) { + using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; + // For FP8 block-scale (E4m3 activations, E4m3 weights) with DeepSeek FP8, use the + // weights-only Runner constructor to match the original kernel path and numerics. 
+ if (this->mDtypeAct == btg::Dtype::E4m3 && this->mDtypeWeights == btg::Dtype::E4m3 && + args->mUseDeepSeekFp8) { + moe_runner = std::make_unique(this->mDtypeWeights, args->mUseDeepSeekFp8, + (int32_t)tile_tokens_dim, this->use_shuffled_weight, + this->weight_layout); + } else { + moe_runner = std::make_unique(this->mDtypeAct, this->mDtypeWeights, + args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, + static_cast(this->gated_act_type), + this->use_shuffled_weight, this->weight_layout); + } + + if (moe_tactic == -1) { + moe_tactic = moe_runner->getDefaultValidConfigIndex( + args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, + args->num_tokens); + } + this->moe_tactic = moe_tactic; - TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8."; - TVM_FFI_ICHECK_EQ(gemm2_weights.ndim(), 3) << "gemm2_weights must be 3D."; - TVM_FFI_ICHECK_EQ(gemm2_weights.size(2), intermediate_size) - << "the third dimension of weights must be equal to intermediate_size."; + auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic); + workspace_fc1 = alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); + workspace_fc2 = alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states.device()); + workspace.bmm1_workspace = workspace_fc1.data_ptr(); + workspace.bmm2_workspace = workspace_fc2.data_ptr(); + } - TVM_FFI_ICHECK_EQ(output2_scales_scalar.dtype(), dl_float32) - << "output2_scales_scalar must be float."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar.ndim(), 1) << "output2_scales_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar.size(0), local_num_experts) - << "output2_scales_scalar has incorrect dim 0."; - - // allocate output - TVM_FFI_ICHECK_EQ(output.size(0), args.num_tokens); - TVM_FFI_ICHECK_EQ(output.size(1), args.hidden_size); - CHECK_INPUT_TYPE(output, dl_bfloat16); - CHECK_DEVICE(output, hidden_states); - - // setup workspace - workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); - workspace.total_max_padded_tokens = - std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); - workspace.ProjUpTileN = tile_tokens_dim; - workspace.routing_expert_indexes = static_cast(expert_indexes.data_ptr()); - workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); - workspace.expanded_idx_to_permuted_idx = static_cast( - expanded_idx_to_permuted_idx.data_ptr()); // Needed by activation/finalize kernels - workspace.permuted_idx_to_token_idx = - static_cast(permuted_idx_to_token_idx.data_ptr()); // Needed by permuteGemm1 kernel - workspace.expert_weights = expert_weights.data_ptr(); // Consumed by finalize kernel - - workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); - workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); - workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); - - // gemm1 intermediate ws - workspace.gemm1_output = gemm1_output.data_ptr(); - workspace.gemm1_output_scale = static_cast(gemm1_output_scale.data_ptr()); - // activation intermediate ws - workspace.activation_output = activation_output.data_ptr(); - workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); - // gemm2 intermediate ws - workspace.gemm2_output = gemm2_output.data_ptr(); - workspace.gemm2_output_scale = nullptr; - args.output = output.data_ptr(); - args.output_scale = nullptr; - - auto workspace_sizes = 
moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); - Tensor workspace_fc1 = - alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); - Tensor workspace_fc2 = - alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states.device()); - workspace.bmm1_workspace = workspace_fc1.data_ptr(); - workspace.bmm2_workspace = workspace_fc2.data_ptr(); - cudaStream_t moe_stream = get_stream(hidden_states.device()); - moe_runner.run(args, workspace, hidden_states.device().device_id, moe_stream, moeConfigIndex, - enable_pdl); + public: + virtual void check_routing() const = 0; + virtual void prepare_routing() = 0; + virtual void check_moe() const = 0; + virtual void prepare_moe(int64_t& moe_tactic) = 0; + + // Main entry point for all the executions. + // Do initializations prior to calling this as the initializations are different for bf16, fp8 and + // fp4. The executions are non-blocking by default. + virtual Array run(int64_t moe_tactic, bool enable_pdl = true, + bool use_routing_scales_on_input = false, + bool use_deep_seek_fp8 = false) { + check_routing(); + prepare_routing(); + + // Execute routing + tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); + cudaStream_t routing_stream = get_stream(hidden_states.device()); + + routing_runner.run( + args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, + args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, + args->routed_scaling_factor, static_cast(expert_indexes.data_ptr()), + static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, + static_cast(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), + static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, + use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(routing_method_type), routing_stream); + + check_moe(); + prepare_moe(moe_tactic); + + cudaStream_t moe_stream = get_stream(hidden_states.device()); + moe_runner->run(*args, workspace, hidden_states.device().device_id, moe_stream, moe_tactic, + enable_pdl); + + if (args->do_finalize) { + return {output}; + } + return {gemm2_output, expert_weights, expanded_idx_to_permuted_idx}; + } +}; + +void FusedMoeLauncher::init_common( + std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, int64_t gated_act_type) { + // Check devicearchitecture: Blackwell (SM 10.x) required + auto device = hidden_states.device().device_id; + int major = 0, minor = 0; + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); + TVM_FFI_ICHECK_EQ(major, 10) << "MoE kernel requires 10.x architecture. Current device has SM " + << major << minor; + this->device_version = std::make_tuple(major, minor); + + args->routing_logits = routing_logits.has_value() ? routing_logits.value().data_ptr() : nullptr; + args->routing_bias = routing_bias.has_value() ? 
routing_bias.value().data_ptr() : nullptr; + args->hidden_states = hidden_states.data_ptr(); + args->gemm1_weights = gemm1_weights.data_ptr(); + args->gemm2_weights = gemm2_weights.data_ptr(); + + this->args = std::move(args); + this->tile_tokens_dim = tile_tokens_dim; + this->routing_method_type = routing_method_type; + this->use_shuffled_weight = use_shuffled_weight; + TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) + << "the value of weight_layout is not recognized"; + this->weight_layout = static_cast(weight_layout); + TVM_FFI_ICHECK(0 <= gated_act_type && gated_act_type <= 1) + << "the value of gated_act_type is not recognized"; + this->gated_act_type = static_cast(gated_act_type); } -void trtllm_fp8_per_tensor_scale_moe( - TensorView routing_logits, Optional routing_bias, TensorView hidden_states, - TensorView gemm1_weights, TensorView output1_scales_scalar, - TensorView output1_scales_gate_scalar, TensorView gemm2_weights, - TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k, - Optional n_group, Optional topk_group, int64_t intermediate_size, - int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl, - Array config_index) { - auto dtype = hidden_states.dtype(); - if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { - using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; +class Bf16MoeLauncher : public FusedMoeLauncher { + public: + static constexpr std::array mSupportedTileNums = {8, 16, 32, 64, 128}; + + Bf16MoeLauncher(TensorView const& routing_logits, Optional const& routing_bias, + TensorView const& hidden_states, TensorView const& gemm1_weights, + TensorView const& gemm2_weights) + : FusedMoeLauncher(Optional(routing_logits), routing_bias, hidden_states, + gemm1_weights, Optional(), Optional(), + gemm2_weights, Optional()) {} + + void init(std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout) { + constexpr int64_t gated_act_type = + static_cast(GatedActType::SwiGlu); // not exposed in api for now + + // Do base class init and perform common checks + FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, + use_shuffled_weight, weight_layout, gated_act_type); + } + + void check_routing() const override { + FusedMoeLauncher::check_routing_common(); + + // TODO n_group, topk_group validation? + } + + void prepare_routing() override { + FusedMoeLauncher::prepare_routing_common(); + + args->mDtypeElt = btg::Dtype::Bfloat16; + args->mUseDeepSeekFp8 = false; + + // Set expert weights dtype based on routing bias + auto const routing_bias_dtype = + routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; + mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? 
btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + + expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + + workspace.expert_weights = expert_weights.data_ptr(); + } + + void check_moe() const override { + FusedMoeLauncher::check_moe_common(); + + TVM_FFI_ICHECK(weight_layout == batchedGemm::gemm::MatrixLayout::BlockMajorK) + << "BF16 Moe: weight_layout must be BlockMajorK"; + check_weights_shape("gemm1"); + check_weights_shape("gemm2"); - // Convert PyTorch dtype to TensorRT-LLM dtype - btg::Dtype mDtypeElt; + TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) + << "the second dimension of weights must be a multiple of 128."; + } + + void prepare_moe(int64_t& moe_tactic) override { + FusedMoeLauncher::prepare_moe_common(moe_tactic); + + int32_t max_num_padded_tokens = workspace.total_max_padded_tokens; + gemm1_output = alloc_tensor({max_num_padded_tokens, args->intermediate_size}, dl_bfloat16, + hidden_states.device()); + activation_output = alloc_tensor({max_num_padded_tokens, args->intermediate_size}, dl_bfloat16, + hidden_states.device()); + gemm2_output = alloc_tensor({max_num_padded_tokens, args->hidden_size}, dl_bfloat16, + hidden_states.device()); + + workspace.hidden_states_scale_linear = nullptr; + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = nullptr; + workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output_scale = nullptr; + workspace.gemm2_output = gemm2_output.data_ptr(); + workspace.gemm2_output_scale = nullptr; + + output = + alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device()); + args->output = output.data_ptr(); + args->output_scale = nullptr; + } + + static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, + int64_t intermediate_size, int64_t num_local_experts, + int64_t num_tokens, int64_t gated_act_type, + bool use_shuffled_weight, int64_t weight_layout) { + Array> valid_configs; + + std::vector supported_tile_nums(mSupportedTileNums.begin(), mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + + for (int32_t tile_N : selected_tile_nums) { + auto moe_runner = std::make_unique( + btg::Dtype::Bfloat16, // dtype_act + btg::Dtype::Bfloat16, // dtype_weights + false, // useDeepSeekFp8 + tile_N, static_cast(gated_act_type), use_shuffled_weight, + static_cast(weight_layout)); + + auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); + + for (auto cfg : cfgs) { + valid_configs.push_back({tile_N, cfg}); + } + } + + return valid_configs; + } +}; + +class Fp8PerTensorLauncher : public FusedMoeLauncher { + public: + static constexpr std::array mSupportedTileNums = {8, 16, 32, 64, 128}; + + // Constructor that passes TensorView parameters to base constructor + Fp8PerTensorLauncher(TensorView const& routing_logits, Optional const& routing_bias, + TensorView const& hidden_states, TensorView const& gemm1_weights, + TensorView const& output1_scales_scalar, + TensorView const& output1_scales_gate_scalar, + TensorView const& gemm2_weights, TensorView const& output2_scales_scalar) + : FusedMoeLauncher(Optional(routing_logits), routing_bias, hidden_states, + gemm1_weights, Optional(output1_scales_scalar), + Optional(output1_scales_gate_scalar), gemm2_weights, + Optional(output2_scales_scalar)), + use_routing_scales_on_input(false) {} + + void init(std::unique_ptr&& args, + 
int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, bool use_routing_scales_on_input_param) { + constexpr int64_t gated_act_type = + static_cast(GatedActType::SwiGlu); // not exposed in api for now + + this->use_routing_scales_on_input = use_routing_scales_on_input_param; + + auto dtype = hidden_states.dtype(); if (dtype == dl_float16) { - mDtypeElt = btg::Dtype::Fp16; + mDtypeAct = btg::Dtype::Fp16; } else if (dtype == dl_bfloat16) { - mDtypeElt = btg::Dtype::Bfloat16; + mDtypeAct = btg::Dtype::Bfloat16; } else if (dtype == dl_float8_e4m3fn) { - mDtypeElt = btg::Dtype::E4m3; + mDtypeAct = btg::Dtype::E4m3; } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for FP8 MoE."; } + mDtypeWeights = btg::Dtype::E4m3; - auto const num_tokens = hidden_states.size(0); - auto const hidden_size = hidden_states.size(1); - bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8 + FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, + use_shuffled_weight, weight_layout, gated_act_type); + } - std::vector mSupportedTileN = {8, 16, 32, 64, 128, 192, 256}; - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + void check_routing() const override { FusedMoeLauncher::check_routing_common(); } - // Build runners for all supported tile sizes - std::unordered_map> mRunners; - for (int32_t tile_N : selected_tile_nums) { - // Always use the two-parameter constructor for consistency - mRunners.emplace(tile_N, std::make_unique(mDtypeElt, mUseDeepSeekFp8, tile_N, - /*useShuffledMatrixA*/ true)); - } + void prepare_routing() override { + FusedMoeLauncher::prepare_routing_common(); - // moeConfigIndex corresponds to pair (tile_N, config) - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - // Autotuner has requested a default or 'fallback' config index - if (tile_N == -1 || config == -1) { - tile_N = *selected_tile_nums.begin(); - config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, - local_num_experts, num_tokens); + auto dtype = hidden_states.dtype(); + if (dtype == dl_float16) { + args->mDtypeElt = btg::Dtype::Fp16; + } else if (dtype == dl_bfloat16) { + args->mDtypeElt = btg::Dtype::Bfloat16; + } else if (dtype == dl_float8_e4m3fn) { + args->mDtypeElt = btg::Dtype::E4m3; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; } - trtllm_fp8_per_tensor_scale_moe_launcher( - routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, - output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, output, num_experts, - top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, use_routing_scales_on_input, tile_N, routing_method_type, - *mRunners[tile_N], config, enable_pdl); - } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype."; - } -} + args->mDtypeOut = btg::Dtype::Bfloat16; + args->mUseDeepSeekFp8 = false; -void trtllm_fp8_block_scale_moe_launcher( - TensorView routing_logits, Optional routing_bias, TensorView hidden_states, - TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, - TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output, - int64_t const num_experts, int64_t const 
top_k, Optional const n_group, - Optional const topk_group, int64_t const intermediate_size, - int64_t const local_expert_offset, int64_t const local_num_experts, - Optional const routed_scaling_factor, int64_t const tile_tokens_dim, - int64_t const routing_method_type, - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex, - bool enable_pdl) { - static const std::tuple device_props = [hidden_states] { - int major, minor; - cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, - hidden_states.device().device_id); - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, - hidden_states.device().device_id); - return std::make_tuple(major, minor); - }(); - - TVM_FFI_ICHECK_EQ(std::get<0>(device_props), 10) - << "This kernel requires 10.x architecture. Current device has SM " - << std::get<0>(device_props) << std::get<1>(device_props); + auto const routing_bias_dtype = + routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; + mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; - } else { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; + expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + + workspace.expert_weights = expert_weights.data_ptr(); } - TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D."; - TVM_FFI_ICHECK_EQ(routing_logits.size(0), hidden_states.size(0)) - << "routing_logits and hidden_states must have the same number of tokens."; - TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) - << "routing_logits dim1 must match num_experts."; - if (routing_bias.has_value()) { - TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || - routing_bias.value().dtype() == dl_float32) - << "routing_bias must be bfloat16 or float."; - TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; - TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) - << "routing_bias has incorrect shape."; + + void check_moe() const override { + FusedMoeLauncher::check_moe_common(); + + TVM_FFI_ICHECK(output1_scales_scalar.has_value()) + << "output1_scales_scalar is required for FP8 MoE"; + TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().dtype(), dl_float32) + << "output1_scales_scalar must be float."; + TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().ndim(), 1) + << "output1_scales_scalar must be 1D."; + TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().size(0), args->local_num_experts) + << "output1_scales_scalar has incorrect dim 0."; + + TVM_FFI_ICHECK(output1_scales_gate_scalar.has_value()) + << "output1_scales_gate_scalar is required for FP8 MoE"; + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().dtype(), dl_float32) + << "output1_scales_gate_scalar must be float."; + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().ndim(), 1) + << "output1_scales_gate_scalar must be 1D."; + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().size(0), args->local_num_experts) + << "output1_scales_gate_scalar has incorrect dim 0."; + + TVM_FFI_ICHECK(output2_scales_scalar.has_value()) + << "output2_scales_scalar is required for FP8 MoE"; + TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().dtype(), dl_float32) + << "output2_scales_scalar must be float."; + 
TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().ndim(), 1) + << "output2_scales_scalar must be 1D."; + TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().size(0), args->local_num_experts) + << "output2_scales_scalar has incorrect dim 0."; + + TVM_FFI_ICHECK(hidden_states.dtype() == dl_float8_e4m3fn || + hidden_states.dtype() == dl_float16 || hidden_states.dtype() == dl_bfloat16) + << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) + << "FP8 MoE: gemm1_weights must be float8_e4m3fn."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) + << "FP8 MoE: gemm2_weights must be float8_e4m3fn."; } - if (n_group.has_value() && n_group.value() != 0) { - TVM_FFI_ICHECK(static_cast(routing_method_type) == - RoutingMethodType::DeepSeekV3) - << "Routing kernel with groups implies DeepSeekV3 routing method."; - TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given"; - TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0) - << "num_experts must be divisible by n_group"; - TVM_FFI_ICHECK(top_k <= 8 && top_k > 0) - << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."; - TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0) - << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."; - TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value()) - << "n_group must not be smaller than topk_group."; - // This check ensures we have enough experts in the selected groups to handle the top_k routing - TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value())) - << "top_k must be less than total number of experts in selected groups"; - } else if (static_cast(routing_method_type) == - RoutingMethodType::Renormalize || - static_cast(routing_method_type) == - RoutingMethodType::RenormalizeNaive) { - TVM_FFI_ICHECK(top_k <= 10 && top_k > 0) - << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0."; - } else if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { - TVM_FFI_ICHECK_EQ(top_k, 1) - << "Current routing kernel (no groups, Llama4) only supports top_k=1."; + void prepare_moe(int64_t& moe_tactic) override { + FusedMoeLauncher::prepare_moe_common(moe_tactic); + + int32_t max_num_padded_tokens_gemm1 = workspace.total_max_padded_tokens + args->num_experts; + int32_t max_num_padded_tokens_gemm2 = workspace.total_max_padded_tokens; + + gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * args->intermediate_size}, + dl_uint8, hidden_states.device()); + gemm1_output_scale = + alloc_tensor({2 * args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, + hidden_states.device()); + + activation_output = alloc_tensor({max_num_padded_tokens_gemm1, args->intermediate_size}, + dl_uint8, hidden_states.device()); + activation_output_scale = + alloc_tensor({args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, + hidden_states.device()); + + gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args->hidden_size}, dl_bfloat16, + hidden_states.device()); + + workspace.hidden_states_scale_linear = nullptr; + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = static_cast(gemm1_output_scale.data_ptr()); + workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); + workspace.gemm2_output = gemm2_output.data_ptr(); + 
workspace.gemm2_output_scale = nullptr; + + output = + alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device()); + args->output = output.data_ptr(); + args->output_scale = nullptr; + args->do_finalize = true; // FP8 per-tensor scale always finalizes + + // Set scale pointers + TVM_FFI_ICHECK(output1_scales_scalar.has_value()); + TVM_FFI_ICHECK(output1_scales_gate_scalar.has_value()); + TVM_FFI_ICHECK(output2_scales_scalar.has_value()); + + args->output1_scales_scalar = static_cast(output1_scales_scalar.value().data_ptr()); + args->output1_scales_gate_scalar = + static_cast(output1_scales_gate_scalar.value().data_ptr()); + args->output2_scales_scalar = static_cast(output2_scales_scalar.value().data_ptr()); } - TVM_FFI_ICHECK_EQ(num_experts % 4, 0) - << "Routing kernel expects that num_experts must be divisible by 4"; - TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k"; - tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs args; - tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace; + private: + bool use_routing_scales_on_input; + Tensor gemm1_output_scale; + Tensor activation_output_scale; - // Convert PyTorch dtype to TensorRT-LLM dtype - auto dtype = hidden_states.dtype(); - if (dtype == dl_float16) { - args.mDtypeElt = btg::Dtype::Fp16; - } else if (dtype == dl_bfloat16) { - args.mDtypeElt = btg::Dtype::Bfloat16; - } else if (dtype == dl_float8_e4m3fn) { - args.mDtypeElt = btg::Dtype::E4m3; - } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; - } + public: + static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, + int64_t intermediate_size, int64_t num_local_experts, + int64_t num_tokens, int64_t gated_act_type, + bool use_shuffled_weight, int64_t weight_layout, + btg::Dtype dtype_act, btg::Dtype dtype_weights) { + Array> valid_configs; - auto const routing_bias_dtype = - routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; - auto btg_routing_bias_dtype = - routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; - - args.routing_logits = static_cast(routing_logits.data_ptr()); - args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr; - args.hidden_states = hidden_states.data_ptr(); - args.hidden_states_scale = static_cast(hidden_states_scale.data_ptr()); - args.gemm1_weights = gemm1_weights.data_ptr(); - args.gemm1_weights_scale = static_cast(gemm1_weights_scale.data_ptr()); - args.gemm2_weights = gemm2_weights.data_ptr(); - args.gemm2_weights_scale = static_cast(gemm2_weights_scale.data_ptr()); - args.num_tokens = hidden_states.size(0); - args.num_experts = num_experts; - args.hidden_size = hidden_states.size(1); - args.hidden_size_output = args.hidden_size; - args.top_k = top_k; - args.n_group = n_group.has_value() ? n_group.value() : 0; - args.topk_group = topk_group.has_value() ? topk_group.value() : 0; - args.local_expert_offset = local_expert_offset; - args.local_num_experts = local_num_experts; - args.routed_scaling_factor = - routed_scaling_factor.has_value() ? 
routed_scaling_factor.value() : 1.0; - args.intermediate_size = intermediate_size; - args.mUseDeepSeekFp8 = true; - - // allocate workspace for routing kernel - Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, routing_logits.device()); - int32_t max_num_padded_tokens = - tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( - args.num_tokens, top_k, num_experts, tile_tokens_dim); - int32_t max_num_padded_tokens_gemm1 = - tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( - max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt)); - int32_t max_num_padded_tokens_gemm2 = - tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( - max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); - Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits.device()); - Tensor expanded_idx_to_permuted_idx = - alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits.device()); - Tensor permuted_idx_to_token_idx = - alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits.device()); - - Tensor expert_weights = - alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits.device()); - // NOTE: the output type of routing kernel is currently always bfloat16 - Tensor expert_indexes = - alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits.device()); - int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); - Tensor expert_count_histogram = alloc_tensor( - {size_of_expert_count_histogram}, - dl_int32, // 256 is the max number of threads per block and max number of experts - routing_logits.device()); - - // allocate workspace for activation/gemm/finalize kernels - Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8, - hidden_states.device()); - Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states.device()); - Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size}, - dl_uint8, hidden_states.device()); - Tensor activation_output_scale = alloc_tensor( - {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states.device()); - Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, - hidden_states.device()); + std::vector supported_tile_nums(mSupportedTileNums.begin(), mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); - int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( - args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); - Tensor cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, routing_logits.device()); - Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, routing_logits.device()); - Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, routing_logits.device()); - - tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); - cudaStream_t stream = get_stream(routing_logits.device()); - routing_runner.run(static_cast(routing_logits.data_ptr()), args.routing_bias, - args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, - args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, - static_cast(expert_indexes.data_ptr()), - 
static_cast(expert_count_histogram.data_ptr()), - static_cast(total_num_padded_tokens.data_ptr()), - static_cast(expanded_idx_to_permuted_idx.data_ptr()), - nullptr /*static_cast(permuted_idx_to_expanded_idx.data_ptr())*/, - static_cast(permuted_idx_to_token_idx.data_ptr()), - expert_weights.data_ptr(), static_cast(num_tokens_per_expert.data_ptr()), - static_cast(cta_idx_xy_to_batch_idx.data_ptr()), - static_cast(cta_idx_xy_to_mn_limit.data_ptr()), - static_cast(num_non_exiting_ctas.data_ptr()), args.mDtypeElt, - btg_routing_bias_dtype, false /* use_routing_scales_on_input */, - true /* use_deep_seek_fp8 */, - static_cast(routing_method_type), stream); - - // MoE kernel except routing - TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32) - << "hidden_states_scale must be float."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128) - << "hidden_states_scale dim0 must match hidden_states dim1 / 128."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args.num_tokens) - << "hidden_states_scale dim1 must match num_tokens."; - TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8."; - - TVM_FFI_ICHECK(gemm1_weights.ndim() == 3 || gemm1_weights.ndim() == 4) - << "gemm1_weights must be 3D or 4D."; - { - int64_t Mn = 0, K = 0; - if (gemm1_weights.ndim() == 3) { - // MajorK [num_experts, M, K] - Mn = gemm1_weights.size(1); - K = gemm1_weights.size(2); - } else if (gemm1_weights.ndim() == 4) { - // BlockMajorK [num_experts, K/block_k, M, block_k] - Mn = gemm1_weights.size(2); - int64_t block_k = gemm1_weights.size(3); - K = gemm1_weights.size(1) * block_k; + for (int32_t tile_N : selected_tile_nums) { + auto moe_runner = std::make_unique( + dtype_act, dtype_weights, + false, // useDeepSeekFp8 + tile_N, static_cast(gated_act_type), use_shuffled_weight, + static_cast(weight_layout)); + + auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); + + for (auto cfg : cfgs) { + valid_configs.push_back({tile_N, cfg}); + } } - TVM_FFI_ICHECK_EQ(Mn % 2, 0) << "the second dimension of weights must be even."; - TVM_FFI_ICHECK_EQ(intermediate_size, Mn / 2) << "intermediate_size has incorrect shape."; - TVM_FFI_ICHECK_EQ(K, hidden_states.size(1)) - << "the third dimension of weights must be equal to hidden_size."; + + return valid_configs; } - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) - << "gemm1_weights_scale must be float."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D."; - - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), local_num_experts) - << "gemm1_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(intermediate_size % 128, 0) - << "the second dimension of weights must be a multiple of 128."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * intermediate_size / 128) - << "gemm1_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args.hidden_size / 128) - << "gemm1_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8."; - - TVM_FFI_ICHECK(gemm2_weights.ndim() == 3 || gemm2_weights.ndim() == 4) - << "gemm2_weights must be 3D or 4D."; - { - int64_t K = 0; - if (gemm2_weights.ndim() == 3) { - // MajorK [num_experts, M, K] - K = 
gemm2_weights.size(2); - } else if (gemm2_weights.ndim() == 4) { - // BlockMajorK [num_experts, K/block_k, M, block_k] - int64_t block_k = gemm2_weights.size(3); - K = gemm2_weights.size(1) * block_k; +}; + +class Fp8BlockScaleLauncher : public FusedMoeLauncher { + public: + static constexpr std::array mSupportedTileNums = {8, 16, 32, 64, 128}; + + Fp8BlockScaleLauncher(TensorView const& routing_logits, Optional const& routing_bias, + TensorView const& hidden_states, TensorView const& hidden_states_scale, + TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale, + TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale) + : FusedMoeLauncher(Optional(routing_logits), routing_bias, hidden_states, + gemm1_weights, Optional(), Optional(), + gemm2_weights, Optional()), + hidden_states_scale(hidden_states_scale), + gemm1_weights_scale(gemm1_weights_scale), + gemm2_weights_scale(gemm2_weights_scale) {} + + void init(std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout) { + constexpr int64_t gated_act_type = static_cast(GatedActType::SwiGlu); + + mDtypeAct = btg::Dtype::E4m3; + mDtypeWeights = btg::Dtype::E4m3; + + auto dtype = hidden_states.dtype(); + if (dtype == dl_float16) { + args->mDtypeElt = btg::Dtype::Fp16; + } else if (dtype == dl_bfloat16) { + args->mDtypeElt = btg::Dtype::Bfloat16; + } else if (dtype == dl_float8_e4m3fn) { + args->mDtypeElt = btg::Dtype::E4m3; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; } - TVM_FFI_ICHECK_EQ(K, intermediate_size) - << "the third dimension of weights must be equal to intermediate_size."; + + // Output is always bfloat16 for FP8 block scale + args->mDtypeOut = btg::Dtype::Bfloat16; + + FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, + use_shuffled_weight, weight_layout, gated_act_type); } - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32) - << "gemm2_weights_scale must be float."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), local_num_experts) - << "gemm2_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args.hidden_size / 128) - << "gemm2_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), intermediate_size / 128) - << "gemm2_weights_scale has incorrect shape."; - - TVM_FFI_ICHECK_EQ(output.size(0), args.num_tokens) << "output has incorrect shape."; - TVM_FFI_ICHECK_EQ(output.size(1), args.hidden_size) << "output has incorrect shape."; - TVM_FFI_ICHECK_EQ(output.dtype(), dl_bfloat16) << "output must be bf16."; - - // setup workspace - workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); - workspace.total_max_padded_tokens = - std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); - workspace.ProjUpTileN = tile_tokens_dim; - workspace.routing_expert_indexes = static_cast(expert_indexes.data_ptr()); - workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); - workspace.expanded_idx_to_permuted_idx = static_cast( - expanded_idx_to_permuted_idx.data_ptr()); // Needed by activation/finalize kernels - workspace.permuted_idx_to_token_idx = - static_cast(permuted_idx_to_token_idx.data_ptr()); // Needed by permuteGemm1 kernel - workspace.expert_weights = expert_weights.data_ptr(); // Consumed by finalize kernel - - 
workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); - workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); - workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); - - // gemm1 intermediate ws - workspace.gemm1_output = gemm1_output.data_ptr(); - workspace.gemm1_output_scale = static_cast(gemm1_output_scale.data_ptr()); - // activation intermediate ws - workspace.activation_output = activation_output.data_ptr(); - workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); - // gemm2 intermediate ws - workspace.gemm2_output = gemm2_output.data_ptr(); - workspace.gemm2_output_scale = nullptr; - args.output = output.data_ptr(); - args.output_scale = nullptr; - - auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); - Tensor workspace_fc1 = - alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); - Tensor workspace_fc2 = - alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states.device()); - workspace.bmm1_workspace = workspace_fc1.data_ptr(); - workspace.bmm2_workspace = workspace_fc2.data_ptr(); - - cudaStream_t moe_stream = get_stream(hidden_states.device()); - moe_runner.run(args, workspace, hidden_states.device().device_id, moe_stream, moeConfigIndex, - enable_pdl); -} -void trtllm_fp8_block_scale_moe( - TensorView routing_logits, Optional routing_bias, TensorView hidden_states, - TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, - TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output, - int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, - int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, - Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, bool enable_pdl, Array config_index) { - auto dtype = hidden_states.dtype(); - if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { - using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; + void check_routing() const override { + FusedMoeLauncher::check_routing_common(); + + if (args->n_group != 0) { + TVM_FFI_ICHECK(static_cast(routing_method_type) == + RoutingMethodType::DeepSeekV3) + << "Routing kernel with groups implies DeepSeekV3 routing method."; + TVM_FFI_ICHECK(args->topk_group != 0) << "if n_group is given, topk_group must be given"; + TVM_FFI_ICHECK_EQ(args->num_experts % args->n_group, 0) + << "num_experts must be divisible by n_group"; + TVM_FFI_ICHECK(args->top_k <= 8 && args->top_k > 0) + << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."; + TVM_FFI_ICHECK(args->topk_group <= 4 && args->topk_group > 0) + << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."; + TVM_FFI_ICHECK_LE(args->topk_group, args->n_group) + << "n_group must not be smaller than topk_group."; + TVM_FFI_ICHECK_LT(args->top_k, (args->topk_group * args->num_experts / args->n_group)) + << "top_k must be less than total number of experts in selected groups"; + } else if (static_cast(routing_method_type) == + RoutingMethodType::Renormalize || + static_cast(routing_method_type) == + RoutingMethodType::RenormalizeNaive) { + TVM_FFI_ICHECK(args->top_k <= 10 && args->top_k > 0) + << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0."; + } else if (static_cast(routing_method_type) == 
RoutingMethodType::Llama4) { + TVM_FFI_ICHECK_EQ(args->top_k, 1) + << "Current routing kernel (no groups, Llama4) only supports top_k=1."; + } + TVM_FFI_ICHECK_EQ(args->num_experts % 4, 0) + << "Routing kernel expects that num_experts must be divisible by 4"; + TVM_FFI_ICHECK_GT(args->num_experts, args->top_k) << "num_experts must be greater than top_k"; + TVM_FFI_ICHECK_LE(args->local_num_experts + args->local_expert_offset, args->num_experts) + << "num_experts must be greater or equal to local_num_experts + local_expert_offset"; + } - btg::Dtype mDtypeElt{btg::Dtype::E4m3}; // FP8 runner so hard-coded - bool mUseDeepSeekFp8{true}; // Always true for BlockScaleMoe + void prepare_routing() override { + FusedMoeLauncher::prepare_routing_common(); - TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) - << "the value of weight_layout is not recognized"; + auto dtype = hidden_states.dtype(); + if (dtype == dl_float16) { + args->mDtypeElt = btg::Dtype::Fp16; + } else if (dtype == dl_bfloat16) { + args->mDtypeElt = btg::Dtype::Bfloat16; + } else if (dtype == dl_float8_e4m3fn) { + args->mDtypeElt = btg::Dtype::E4m3; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; + } + + args->mUseDeepSeekFp8 = true; + args->routing_logits = static_cast(routing_logits.value().data_ptr()); + // Set expert weights dtype based on routing bias + auto const routing_bias_dtype = + routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; + mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + + expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + workspace.expert_weights = expert_weights.data_ptr(); + } - auto const num_tokens = hidden_states.size(0); - auto const hidden_size = hidden_states.size(1); + void check_moe() const override { + FusedMoeLauncher::check_moe_common(); - std::vector mSupportedTileN = {8, 16, 32, 64, 128}; + TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32) + << "hidden_states_scale must be float."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128) + << "hidden_states_scale dim0 must match hidden_states dim1 / 128."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens) + << "hidden_states_scale dim1 must match num_tokens."; + + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8."; + + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) + << "gemm1_weights_scale must be float."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts) + << "gemm1_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) + << "intermediate_size must be a multiple of 128."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128) + << "gemm1_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128) + << "gemm1_weights_scale has incorrect shape."; + + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32) + << "gemm2_weights_scale must be float."; + 
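// Taken together, the dtype/shape checks in this check_moe() pin down the expected
// block-scale layout (128-wide quantization blocks). A worked example, assuming
// hidden_size = 7168, intermediate_size = 2048, num_tokens = 4, local_num_experts = E = 8:
//   hidden_states_scale : float32 [hidden_size/128, num_tokens]                   -> [56, 4]
//   gemm1_weights_scale : float32 [E, 2*intermediate_size/128, hidden_size/128]   -> [8, 32, 56]
//   gemm2_weights_scale : float32 [E, hidden_size/128, intermediate_size/128]     -> [8, 56, 16]
// intermediate_size is explicitly required to be a multiple of 128; hidden_size is divided
// into the same 128-wide blocks by the scale-shape checks.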
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts) + << "gemm2_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128) + << "gemm2_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128) + << "gemm2_weights_scale has incorrect shape."; + + check_weights_shape("gemm1"); + check_weights_shape("gemm2"); + TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) + << "intermediate_size must be a multiple of 128."; + } + + void prepare_moe(int64_t& moe_tactic) override { + FusedMoeLauncher::prepare_moe_common(moe_tactic); + + // Calculate max_num_padded_tokens for gemm1 and gemm2 using maybeGetMinTokenCount + int32_t max_num_padded_tokens_gemm1 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + workspace.total_max_padded_tokens, args->intermediate_size, + btg::dtypeGetNumBits(args->mDtypeElt)); + int32_t max_num_padded_tokens_gemm2 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + workspace.total_max_padded_tokens, args->hidden_size, + btg::dtypeGetNumBits(args->mDtypeOut)); + + gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * args->intermediate_size}, + dl_uint8, hidden_states.device()); + gemm1_output_scale = + alloc_tensor({2 * args->intermediate_size / 128, workspace.total_max_padded_tokens}, + dl_float32, hidden_states.device()); + + activation_output = alloc_tensor({max_num_padded_tokens_gemm1, args->intermediate_size}, + dl_uint8, hidden_states.device()); + activation_output_scale = + alloc_tensor({args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, + hidden_states.device()); + + gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args->hidden_size}, dl_bfloat16, + hidden_states.device()); + + workspace.hidden_states_scale_linear = nullptr; + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = static_cast(gemm1_output_scale.data_ptr()); + workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); + workspace.gemm2_output = gemm2_output.data_ptr(); + workspace.gemm2_output_scale = nullptr; + + output = + alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device()); + args->output = output.data_ptr(); + args->output_scale = nullptr; + args->do_finalize = true; + + args->hidden_states_scale = static_cast(hidden_states_scale.data_ptr()); + args->gemm1_weights_scale = static_cast(gemm1_weights_scale.data_ptr()); + args->gemm2_weights_scale = static_cast(gemm2_weights_scale.data_ptr()); + } + + private: + TensorView hidden_states_scale; + TensorView gemm1_weights_scale; + TensorView gemm2_weights_scale; + Tensor gemm1_output_scale; + Tensor activation_output_scale; + + public: + static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, + int64_t intermediate_size, int64_t num_local_experts, + int64_t num_tokens, bool use_shuffled_weight, + int64_t weight_layout, btg::Dtype dtype_weights) { + Array> valid_configs; + + std::vector supported_tile_nums(mSupportedTileNums.begin(), mSupportedTileNums.end()); std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); - // Build runners for all supported tile 
sizes - std::unordered_map> mRunners; for (int32_t tile_N : selected_tile_nums) { - mRunners.emplace(tile_N, std::make_unique( - mDtypeElt, mUseDeepSeekFp8, tile_N, use_shuffled_weight, - static_cast(weight_layout))); - } + auto moe_runner = std::make_unique( + dtype_weights, // dtype_weights for DeepSeek FP8 + true, // useDeepSeekFp8 + tile_N, use_shuffled_weight, static_cast(weight_layout)); - // moeConfigIndex corresponds to pair (tile_N, config) - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - // Autotuner has requested a default or 'fallback' config index - if (tile_N == -1 || config == -1) { - tile_N = *selected_tile_nums.begin(); - config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, - local_num_experts, num_tokens); + auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); + + for (auto cfg : cfgs) { + valid_configs.push_back({tile_N, cfg}); + } } - trtllm_fp8_block_scale_moe_launcher( - routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, - gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output, num_experts, top_k, - n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, tile_N, routing_method_type, *mRunners[tile_N], config, enable_pdl); - } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported hidden state dtype."; + return valid_configs; } -} +}; -// TODO(siyuan): This launcher supports flexible weight and activation types. -// We should cleanup other launchers and only use this one in the future. -Array trtllm_fp4_block_scale_moe_launcher( - Optional routing_logits, TensorView expert_indices, TensorView expert_weights, - Optional routing_bias, TensorView hidden_states, - Optional hidden_states_scale, TensorView gemm1_weights, - TensorView gemm1_weights_scale, Optional gemm1_bias, - Optional gemm1_alpha, Optional gemm1_beta, - Optional gemm1_clamp_limit, TensorView gemm2_weights, - TensorView gemm2_weights_scale, Optional gemm2_bias, - Optional output1_scales_scalar, Optional output1_scales_gate_scalar, - Optional output2_scales_scalar, int64_t const num_experts, int64_t const top_k, - Optional const n_group, Optional const topk_group, - int64_t const intermediate_size, int64_t const local_expert_offset, - int64_t const local_num_experts, Optional const routed_scaling_factor, - int64_t const tile_tokens_dim, int64_t const routing_method_type, bool const do_finalize, - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, btg::Dtype dtype_act, - btg::Dtype dtype_weights, int64_t const moeConfigIndex, bool enable_pdl, TensorView output) { - static const std::tuple device_props = [hidden_states] { - int major, minor; - cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, - hidden_states.device().device_id); - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, - hidden_states.device().device_id); - return std::make_tuple(major, minor); - }(); - - TVM_FFI_ICHECK_EQ(std::get<0>(device_props), 10) - << "This kernel requires 10.x architecture. 
Current device has SM " - << std::get<0>(device_props) << std::get<1>(device_props); - - TVM_FFI_ICHECK(dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::Bfloat16 || - dtype_act == btg::Dtype::E4m3 || dtype_act == btg::Dtype::MxE4m3) - << "Only E2m1, Bfloat16, MxE4m3 and E4m3 are supported by block scale MoE"; - if (dtype_act == btg::Dtype::E2m1) { - TVM_FFI_ICHECK(dtype_weights == btg::Dtype::E2m1) - << "Only E2m1 and MxE2m1 are supported by block scale MoE with E2m1 activation"; - TVM_FFI_ICHECK(hidden_states_scale.has_value()) - << "hidden_states_scale is required for E2m1 activation"; - TVM_FFI_ICHECK(output1_scales_scalar.has_value()) - << "output1_scales_scalar is required for E2m1 activation"; - TVM_FFI_ICHECK(output1_scales_gate_scalar.has_value()) - << "output1_scales_gate_scalar is required for E2m1 activation"; - TVM_FFI_ICHECK(output2_scales_scalar.has_value()) - << "output2_scales_scalar is required for E2m1 activation"; - } else if (dtype_act == btg::Dtype::Bfloat16 || dtype_act == btg::Dtype::E4m3 || - dtype_act == btg::Dtype::MxE4m3) { - TVM_FFI_ICHECK(dtype_weights == btg::Dtype::MxE2m1) - << "Only MxE2m1 weights are supported by block scale MoE with Bfloat16, E4m3 or " - "MxE4m3 activation"; - } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported act dtype."; - } +class FP4BlockScaleLauncher : public FusedMoeLauncher { + public: + static constexpr std::array mBaseSupportedTileNums = {8, 16, 32, 64}; - if (dtype_act == btg::Dtype::E4m3) { - TVM_FFI_ICHECK(output1_scales_scalar.has_value()) - << "output1_scales_scalar is required for E4m3 activation"; - TVM_FFI_ICHECK(output1_scales_gate_scalar.has_value()) - << "output1_scales_gate_scalar is required for E4m3 activation"; - TVM_FFI_ICHECK(output2_scales_scalar.has_value()) - << "output2_scales_scalar is required for E4m3 activation"; + static std::vector getSupportedTileNums(btg::Dtype dtype_act) { + std::vector tiles(mBaseSupportedTileNums.begin(), mBaseSupportedTileNums.end()); + if (dtype_act != btg::Dtype::Bfloat16) { + tiles.push_back(128); + tiles.push_back(256); + } + return tiles; } - if (routing_logits.has_value()) { - TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || - routing_logits.value().dtype() == dl_bfloat16) - << "routing_logits must be float or bfloat16."; - TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; - TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts) - << "routing_logits has incorrect shape."; + FP4BlockScaleLauncher( + Optional const& routing_logits, Optional const& routing_bias, + TensorView const& hidden_states, Optional const& hidden_states_scale, + TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale, + Optional const& gemm1_bias, Optional const& gemm1_alpha, + Optional const& gemm1_beta, Optional const& gemm1_clamp_limit, + TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale, + Optional const& gemm2_bias, Optional const& output1_scales_scalar, + Optional const& output1_scales_gate_scalar, + Optional const& output2_scales_scalar, TensorView const& expert_indices, + TensorView const& expert_weights) + : FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights, + output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, + output2_scales_scalar), + hidden_states_scale(hidden_states_scale), + gemm1_weights_scale(gemm1_weights_scale), + gemm1_bias(gemm1_bias), + gemm1_alpha(gemm1_alpha), + gemm1_beta(gemm1_beta), + 
gemm1_clamp_limit(gemm1_clamp_limit), + gemm2_weights_scale(gemm2_weights_scale), + gemm2_bias(gemm2_bias), + expert_indices(expert_indices), + expert_weights(expert_weights) {} + + void init(std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, int64_t gated_act_type, btg::Dtype dtype_act, + btg::Dtype dtype_weights) { + static const std::tuple device_props = [this] { + int major, minor; + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, + hidden_states.device().device_id); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, + hidden_states.device().device_id); + return std::make_tuple(major, minor); + }(); + + TVM_FFI_ICHECK_EQ(std::get<0>(device_props), 10) + << "This kernel requires 10.x architecture. Current device has SM " + << std::get<0>(device_props) << std::get<1>(device_props); + + // Set data types + args->mDtypeElt = dtype_act; + args->mDtypeOut = btg::Dtype::Bfloat16; // Output is always BF16 for FP4 + args->mUseDeepSeekFp8 = false; // FP4 doesn't use DeepSeek FP8 + + mDtypeAct = dtype_act; + mDtypeWeights = dtype_weights; + + FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, + use_shuffled_weight, weight_layout, gated_act_type); } - if (routing_bias.has_value()) { - TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || - routing_bias.value().dtype() == dl_float32) - << "routing_bias must be bfloat16 or float."; - TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; - TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) - << "routing_bias has incorrect shape."; + void check_routing() const override { + // First call base class common routing checks + FusedMoeLauncher::check_routing_common(); } - if (n_group.value_or(0) != 0) { - TVM_FFI_ICHECK(static_cast(routing_method_type) == - RoutingMethodType::DeepSeekV3) - << "Routing kernel with groups implies DeepSeekV3 routing method."; - TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given"; - TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0) - << "num_experts must be divisible by n_group"; - TVM_FFI_ICHECK(top_k <= 10 && top_k > 0) - << "Current routing kernel (with groups) only supports top_k<=10 && top_k>0."; - TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0) - << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."; - TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value()) - << "n_group must not be smaller than topk_group."; - // This check ensures we have enough experts in the selected groups to handle the top_k routing - TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value())) - << "top_k must be less than total number of experts in selected groups"; - } else if (static_cast(routing_method_type) == - RoutingMethodType::Renormalize || - static_cast(routing_method_type) == - RoutingMethodType::RenormalizeNaive || - static_cast(routing_method_type) == RoutingMethodType::TopK) { - TVM_FFI_ICHECK(top_k <= 10 && top_k > 0) - << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=10 && " - "top_k>0."; - } else if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { - TVM_FFI_ICHECK_EQ(top_k, 1) - << "Current routing kernel (no groups, Llama4) only supports top_k=1."; + void prepare_routing() override { + num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device()); + int32_t 
max_num_padded_tokens = + tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( + args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + + total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device()); + expanded_idx_to_permuted_idx = + alloc_tensor({args->num_tokens * args->top_k}, dl_int32, hidden_states.device()); + permuted_idx_to_token_idx = + alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device()); + + int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2); + expert_count_histogram = + alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states.device()); + + int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( + args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); + cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); + num_non_exiting_ctas = alloc_tensor({1}, dl_int32, hidden_states.device()); + + workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); + workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.ProjUpTileN = tile_tokens_dim; + workspace.routing_expert_indexes = + static_cast(const_cast(expert_indices.data_ptr())); + workspace.expert_weights = const_cast(expert_weights.data_ptr()); + workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); + workspace.expanded_idx_to_permuted_idx = + static_cast(expanded_idx_to_permuted_idx.data_ptr()); + workspace.permuted_idx_to_token_idx = static_cast(permuted_idx_to_token_idx.data_ptr()); + workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); + workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); + workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); + + args->mDtypeElt = mDtypeAct; + auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; + mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? 
btg::Dtype::Bfloat16 : btg::Dtype::Fp32; } - TVM_FFI_ICHECK_EQ(num_experts % 4, 0) - << "Routing kernel expects that num_experts must be divisible by 4"; - TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k"; - - tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs args; - tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace; + void check_moe() const override { + TVM_FFI_ICHECK(mDtypeAct == btg::Dtype::E2m1 || mDtypeAct == btg::Dtype::Bfloat16 || + mDtypeAct == btg::Dtype::E4m3 || mDtypeAct == btg::Dtype::MxE4m3) + << "Only E2m1, Bfloat16, MxE4m3 and E4m3 are supported by block scale MoE"; + + if (mDtypeAct == btg::Dtype::E2m1) { + TVM_FFI_ICHECK(mDtypeWeights == btg::Dtype::E2m1) + << "Only E2m1 and MxE2m1 are supported by block scale MoE with E2m1 activation"; + TVM_FFI_ICHECK(hidden_states_scale.has_value()) + << "hidden_states_scale is required for E2m1 activation"; + TVM_FFI_ICHECK(output1_scales_scalar.has_value()) + << "output1_scales_scalar is required for E2m1 activation"; + TVM_FFI_ICHECK(output1_scales_gate_scalar.has_value()) + << "output1_scales_gate_scalar is required for E2m1 activation"; + TVM_FFI_ICHECK(output2_scales_scalar.has_value()) + << "output2_scales_scalar is required for E2m1 activation"; + } else if (mDtypeAct == btg::Dtype::Bfloat16 || mDtypeAct == btg::Dtype::E4m3 || + mDtypeAct == btg::Dtype::MxE4m3) { + TVM_FFI_ICHECK(mDtypeWeights == btg::Dtype::MxE2m1) + << "Only MxE2m1 weights are supported by block scale MoE with Bfloat16, E4m3 or " + "MxE4m3 activation"; + } - // setup args - args.mDtypeElt = dtype_act; - // note: the assumption is that output data type is always Bfloat16 (the default) - auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; - auto btg_routing_bias_dtype = - routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; - // We shouln't use args.mDtypeExpW since it indicates the output data type of routing kernel, - // which is currently always bfloat16 for routing kernel while the data type of routing bias now - // can be fp32 - args.routing_logits = routing_logits.has_value() ? routing_logits.value().data_ptr() : nullptr; - args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr; - args.hidden_states = hidden_states.data_ptr(); - args.hidden_states_scale = - hidden_states_scale.has_value() ? hidden_states_scale.value().data_ptr() : nullptr; - args.gemm1_weights = gemm1_weights.data_ptr(); - args.gemm1_weights_scale = gemm1_weights_scale.data_ptr(); - args.gemm1_bias = - gemm1_bias.has_value() ? static_cast(gemm1_bias.value().data_ptr()) : nullptr; - args.gemm1_alpha = - gemm1_alpha.has_value() ? static_cast(gemm1_alpha.value().data_ptr()) : nullptr; - args.gemm1_beta = - gemm1_beta.has_value() ? static_cast(gemm1_beta.value().data_ptr()) : nullptr; - args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() - ? static_cast(gemm1_clamp_limit.value().data_ptr()) - : nullptr; - args.gemm2_weights = gemm2_weights.data_ptr(); - args.gemm2_weights_scale = gemm2_weights_scale.data_ptr(); - args.gemm2_bias = - gemm2_bias.has_value() ? static_cast(gemm2_bias.value().data_ptr()) : nullptr; - args.num_tokens = hidden_states.size(0); - args.num_experts = num_experts; - // * 2 to compensate for the fact that sizeof(hidden_states.dtype) is 1 because we pack 2 e2m1 - // into 1 byte. - auto const hidden_states_hidden_size = - dtype_act == btg::Dtype::E2m1 ? 
hidden_states.size(1) * 2 : hidden_states.size(1); - args.hidden_size = hidden_states_hidden_size; - args.hidden_size_output = args.hidden_size; - args.top_k = top_k; - args.n_group = n_group.value_or(0); - args.topk_group = topk_group.value_or(0); - args.local_expert_offset = local_expert_offset; - args.local_num_experts = local_num_experts; - args.routed_scaling_factor = routed_scaling_factor.value_or(1.0); - args.intermediate_size = intermediate_size; - - // allocate workspace for routing kernel - Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, hidden_states.device()); - int32_t max_num_padded_tokens = - tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( - args.num_tokens, top_k, num_experts, tile_tokens_dim); - int32_t max_num_padded_tokens_gemm1 = - tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( - max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt)); - int32_t max_num_padded_tokens_gemm2 = - tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( - max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); - Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device()); - Tensor expanded_idx_to_permuted_idx = - alloc_tensor({args.num_tokens, args.top_k}, dl_int32, hidden_states.device()); - - Tensor permuted_idx_to_token_idx = - alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device()); - int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); - Tensor expert_count_histogram = - alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states.device()); - - auto const sf_vec_size = dtype_weights == btg::Dtype::MxE2m1 ? 32 : 16; - - // allocate workspace for activation/gemm/finalize kernels - auto const gemm1_output_hidden = - dtype_act == btg::Dtype::E2m1 ? intermediate_size / 2 : intermediate_size; - Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden}, - dtype_act == btg::Dtype::Bfloat16 ? 
dl_bfloat16 : dl_uint8, - hidden_states.device()); + if (mDtypeAct == btg::Dtype::E4m3) { + TVM_FFI_ICHECK(output1_scales_scalar.has_value()) + << "output1_scales_scalar is required for E4m3 activation"; + TVM_FFI_ICHECK(output1_scales_gate_scalar.has_value()) + << "output1_scales_gate_scalar is required for E4m3 activation"; + TVM_FFI_ICHECK(output2_scales_scalar.has_value()) + << "output2_scales_scalar is required for E4m3 activation"; + } - Optional gemm1_output_scale = std::nullopt; - if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) { - int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens_gemm1, - intermediate_size / sf_vec_size); - // gemm1_output_scale = alloc_tensor({sf_size}, dl_float8_e4m3fn, hidden_states.device()); - gemm1_output_scale = alloc_tensor({sf_size}, dl_uint8, hidden_states.device()); + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_uint8) << "gemm1_weights must be byte."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float8_e4m3fn) + << "gemm1_weights_scale must be fp8."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_uint8) << "gemm2_weights must be byte."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float8_e4m3fn) + << "gemm2_weights_scale must be fp8."; } - Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, - hidden_states.device()); + void prepare_moe(int64_t& moe_tactic) override { + args->hidden_states = hidden_states.data_ptr(); + args->hidden_states_scale = + hidden_states_scale.has_value() ? hidden_states_scale.value().data_ptr() : nullptr; + args->gemm1_weights = gemm1_weights.data_ptr(); + args->gemm1_weights_scale = gemm1_weights_scale.data_ptr(); + args->gemm1_bias = + gemm1_bias.has_value() ? static_cast(gemm1_bias.value().data_ptr()) : nullptr; + args->gemm1_alpha = + gemm1_alpha.has_value() ? static_cast(gemm1_alpha.value().data_ptr()) : nullptr; + args->gemm1_beta = + gemm1_beta.has_value() ? static_cast(gemm1_beta.value().data_ptr()) : nullptr; + args->gemm1_clamp_limit = gemm1_clamp_limit.has_value() + ? static_cast(gemm1_clamp_limit.value().data_ptr()) + : nullptr; + args->gemm2_weights = gemm2_weights.data_ptr(); + args->gemm2_weights_scale = gemm2_weights_scale.data_ptr(); + args->gemm2_bias = + gemm2_bias.has_value() ? static_cast(gemm2_bias.value().data_ptr()) : nullptr; + args->output1_scales_scalar = + output1_scales_scalar.has_value() + ? static_cast(output1_scales_scalar.value().data_ptr()) + : nullptr; + args->output1_scales_gate_scalar = + output1_scales_gate_scalar.has_value() + ? static_cast(output1_scales_gate_scalar.value().data_ptr()) + : nullptr; + args->output2_scales_scalar = + output2_scales_scalar.has_value() + ? static_cast(output2_scales_scalar.value().data_ptr()) + : nullptr; + + FusedMoeLauncher::prepare_moe_common(moe_tactic); + + auto const sf_vec_size = mDtypeWeights == btg::Dtype::MxE2m1 ? 32 : 16; + + max_num_padded_tokens_gemm1 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + workspace.total_max_padded_tokens, args->intermediate_size, + btg::dtypeGetNumBits(mDtypeAct)); + max_num_padded_tokens_gemm2 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + workspace.total_max_padded_tokens, args->hidden_size, + btg::dtypeGetNumBits(btg::Dtype::Bfloat16)); // Output is always BF16 + + auto const gemm1_output_hidden = + mDtypeAct == btg::Dtype::E2m1 ? 
args->intermediate_size / 2 : args->intermediate_size; + gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden}, + mDtypeAct == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_uint8, + hidden_states.device()); + + if (mDtypeAct == btg::Dtype::E2m1 || mDtypeAct == btg::Dtype::MxE4m3) { + int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize( + max_num_padded_tokens_gemm1, args->intermediate_size / sf_vec_size); + gemm1_output_scale = alloc_tensor({sf_size}, dl_uint8, hidden_states.device()); + } - int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( - args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); - Tensor cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); - Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); - Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, hidden_states.device()); - - // - // TopK routing - // - - tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); - cudaStream_t stream = get_stream(hidden_states.device()); - routing_runner.run( - args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, - args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, - args.routed_scaling_factor, static_cast(expert_indices.data_ptr()), - static_cast(expert_count_histogram.data_ptr()), - static_cast(total_num_padded_tokens.data_ptr()), - static_cast(expanded_idx_to_permuted_idx.data_ptr()), - nullptr, /*static_cast(permuted_idx_to_expanded_idx.data_ptr()),*/ - static_cast(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), - static_cast(num_tokens_per_expert.data_ptr()), - static_cast(cta_idx_xy_to_batch_idx.data_ptr()), - static_cast(cta_idx_xy_to_mn_limit.data_ptr()), - static_cast(num_non_exiting_ctas.data_ptr()), args.mDtypeElt, btg_routing_bias_dtype, - false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */, - static_cast(routing_method_type), stream); - - // - // FC13 (gemm1) + FC2 (gemm2) - // - - if (dtype_act == btg::Dtype::E2m1) { - TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_uint8) << "hidden_states must be byte."; - } else if (dtype_act == btg::Dtype::E4m3 || dtype_act == btg::Dtype::MxE4m3) { - TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8."; - } else if (dtype_act == btg::Dtype::Bfloat16) { - TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_bfloat16) << "hidden_states must be bfloat16."; - } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported act dtype."; + // Allocate gemm2_output + gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args->hidden_size}, dl_bfloat16, + hidden_states.device()); + + // Setup workspace pointers + workspace.hidden_states_scale_linear = nullptr; // FP4 doesn't use linear scale + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = gemm1_output_scale.has_value() + ? 
static_cast(gemm1_output_scale.value().data_ptr()) + : nullptr; + // Note: activation_output and activation_output_scale are set by the base class + // prepare_moe_common() when gated activation is used + workspace.gemm2_output = gemm2_output.data_ptr(); + workspace.gemm2_output_scale = nullptr; } - if (hidden_states_scale.has_value()) { - TVM_FFI_ICHECK_EQ(hidden_states_scale.value().dtype(), dl_float8_e4m3fn) - << "hidden_states_scale must be fp8."; + private: + Optional hidden_states_scale; + TensorView gemm1_weights_scale; + Optional gemm1_bias; + Optional gemm1_alpha; + Optional gemm1_beta; + Optional gemm1_clamp_limit; + TensorView gemm2_weights_scale; + Optional gemm2_bias; + int32_t max_num_padded_tokens_gemm1{}; + int32_t max_num_padded_tokens_gemm2{}; + Optional gemm1_output_scale; + TensorView expert_indices; + TensorView expert_weights; + + public: + Array run(int64_t moe_tactic, bool enable_pdl = true, + bool use_routing_scales_on_input = false, + bool use_deep_seek_fp8 = false) override { + check_routing(); + prepare_routing(); + + // Execute routing + tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); + cudaStream_t routing_stream = get_stream(hidden_states.device()); + + routing_runner.run( + args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, + args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, + args->routed_scaling_factor, static_cast(expert_indices.data_ptr()), + static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, + static_cast(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), + static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, + use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(routing_method_type), routing_stream); + + check_moe(); + prepare_moe(moe_tactic); + + cudaStream_t moe_stream = get_stream(hidden_states.device()); + moe_runner->run(*args, workspace, hidden_states.device().device_id, moe_stream, moe_tactic, + enable_pdl); + + // Match original FP4 behavior for return values + if (args->do_finalize) { + return {}; + } + return {gemm2_output, expanded_idx_to_permuted_idx}; + } + + static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, + int64_t intermediate_size, int64_t num_local_experts, + int64_t num_tokens, int64_t gated_act_type, + btg::Dtype dtype_act, btg::Dtype dtype_weights) { + Array> valid_configs; + + std::vector tile_sizes = getSupportedTileNums(dtype_act); + std::set selected_tile_nums = + computeSelectedTileN(tile_sizes, num_tokens, top_k, num_local_experts); + + for (int32_t tile_N : selected_tile_nums) { + auto moe_runner = std::make_unique( + dtype_act, dtype_weights, + false, // useDeepSeekFp8 + tile_N, static_cast(gated_act_type), + /*useShuffledMatrixA*/ true); // FP4 uses shuffled weights + + auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); - TVM_FFI_ICHECK_EQ( - hidden_states_scale.value().numel(), - tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens, args.hidden_size / sf_vec_size)) - << "hidden_states_scale has incorrect size"; + for (auto cfg : cfgs) { + 
valid_configs.push_back({tile_N, cfg}); + } + } + + return valid_configs; } +}; + +Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional const& routing_bias, + TensorView const& hidden_states, TensorView const& gemm1_weights, + TensorView const& gemm2_weights, int64_t num_experts, int64_t top_k, + Optional n_group, Optional topk_group, + int64_t intermediate_size, int64_t local_expert_offset, + int64_t local_num_experts, int64_t routing_method_type, + bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl, + Array moe_tactic) { + // Just some basic type validation first and leave more checks to the launcher + TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16) + << "BF16 MoE: routing_logits must be bfloat16 or float."; + TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_bfloat16) + << "BF16 MoE: hidden_states must be bfloat16."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_bfloat16) + << "BF16 MoE: gemm1_weights must be bfloat16."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_bfloat16) + << "BF16 MoE: gemm2_weights must be bfloat16."; + + auto const num_tokens = hidden_states.size(0); + auto const hidden_size = hidden_states.size(1); + + // Calculate supported tile sizes + std::vector mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(), + Bf16MoeLauncher::mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); - TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_uint8) << "gemm1_weights must be byte."; - - TVM_FFI_ICHECK_EQ(gemm1_weights.ndim(), 3) << "gemm1_weights must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights.size(1) % 2, 0) - << "the second dimension of weights must be even."; - TVM_FFI_ICHECK_EQ(intermediate_size, gemm1_weights.size(1) / 2) - << "intermediate_size has incorrect dim 1."; - // This check passes even though the actual shape of the weights[2] and hidden_states[1] is - // 2 times larger due to the fact that 2 e2m1 are packed into 1 byte. - TVM_FFI_ICHECK_EQ( - gemm1_weights.size(2), - (dtype_act == btg::Dtype::E2m1 ? 
hidden_states.size(1) : hidden_states.size(1) / 2)) - << "the third dimension of weights must be equal to hidden_size."; - - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float8_e4m3fn) - << "gemm1_weights_scale must be fp8."; - - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), local_num_experts) - << "gemm1_weights_scale has incorrect dim 0."; - TVM_FFI_ICHECK_EQ(intermediate_size % sf_vec_size, 0) - << "the second dimension of weights must be a multiple of ", - sf_vec_size; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * intermediate_size) - << "gemm1_weights_scale has incorrect dim 1."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args.hidden_size / sf_vec_size) - << "gemm1_weights_scale has incorrect dim 2."; - - if (gemm1_bias.has_value()) { - TVM_FFI_ICHECK_EQ(gemm1_bias.value().dtype(), dl_float32) - << "gemm1_bias must be float, got " - << tvm::ffi::DLDataTypeToString(gemm1_bias.value().dtype()); - TVM_FFI_ICHECK_EQ(gemm1_bias.value().ndim(), 2) << "gemm1_bias must be 2D."; - TVM_FFI_ICHECK_EQ(gemm1_bias.value().size(0), local_num_experts) - << "gemm1_bias has incorrect dim 0."; - TVM_FFI_ICHECK_EQ(gemm1_bias.value().size(1), 2 * intermediate_size) - << "gemm1_bias has incorrect dim 1."; + // Create a map of launchers for each tile size + std::unordered_map> launchers_map; + + for (int32_t curr_tile_N : selected_tile_nums) { + // Create MoE arguments for this launcher + auto args = std::make_unique(); + args->num_tokens = num_tokens; + args->num_experts = num_experts; + args->hidden_size = hidden_size; + args->hidden_size_output = args->hidden_size; + args->top_k = top_k; + args->n_group = n_group.value_or(0); + args->topk_group = topk_group.value_or(0); + ; + args->local_expert_offset = local_expert_offset; + args->local_num_experts = local_num_experts; + args->intermediate_size = intermediate_size; + + // Create and initialize launcher for this tile size + auto launcher = std::make_unique(routing_logits, routing_bias, hidden_states, + gemm1_weights, gemm2_weights); + launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, + weight_layout); + + launchers_map[curr_tile_N] = std::move(launcher); } - if (gemm1_alpha.has_value()) { - TVM_FFI_ICHECK_EQ(gemm1_alpha.value().dtype(), dl_float32) - << "gemm1_alpha must be float, got " - << tvm::ffi::DLDataTypeToString(gemm1_alpha.value().dtype()); - TVM_FFI_ICHECK_EQ(gemm1_alpha.value().ndim(), 1) << "gemm1_alpha must be 1D."; - TVM_FFI_ICHECK_EQ(gemm1_alpha.value().size(0), local_num_experts) - << "gemm1_alpha has incorrect dim 0."; + // Extract tile_N and config from moe_tactic + int64_t tile_N = moe_tactic[0]; + int64_t config = moe_tactic[1]; + + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); } - if (gemm1_beta.has_value()) { - TVM_FFI_ICHECK_EQ(gemm1_beta.value().dtype(), dl_float32) - << "gemm1_beta must be float, got " - << tvm::ffi::DLDataTypeToString(gemm1_beta.value().dtype()); - TVM_FFI_ICHECK_EQ(gemm1_beta.value().ndim(), 1) << "gemm1_beta must be 1D."; - TVM_FFI_ICHECK_EQ(gemm1_beta.value().size(0), local_num_experts) - << "gemm1_beta has incorrect dim 0."; + + // Get the launcher for the selected tile_N + auto& selected_launcher = launchers_map.at(tile_N); + + // Run the launcher - it will create its own runner internally + auto result = selected_launcher->run(config, enable_pdl)[0]; + return result; +} + +Tensor 
trtllm_fp8_per_tensor_scale_moe( + TensorView routing_logits, Optional routing_bias, TensorView hidden_states, + TensorView gemm1_weights, TensorView output1_scales_scalar, + TensorView output1_scales_gate_scalar, TensorView gemm2_weights, + TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k, + Optional n_group, Optional topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, + bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl, + Array config_index) { + // Basic type validation + auto dtype = hidden_states.dtype(); + if (use_routing_scales_on_input) { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; + } else if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; + } else { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; } + TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16) + << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) + << "FP8 MoE: gemm1_weights must be float8_e4m3fn."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) + << "FP8 MoE: gemm2_weights must be float8_e4m3fn."; + TVM_FFI_ICHECK_EQ(output1_scales_scalar.dtype(), dl_float32) + << "FP8 MoE: output1_scales_scalar must be float32."; + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.dtype(), dl_float32) + << "FP8 MoE: output1_scales_gate_scalar must be float32."; + TVM_FFI_ICHECK_EQ(output2_scales_scalar.dtype(), dl_float32) + << "FP8 MoE: output2_scales_scalar must be float32."; - TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_uint8) << "gemm2_weights must be byte."; + auto const num_tokens = hidden_states.size(0); + auto const hidden_size = hidden_states.size(1); - TVM_FFI_ICHECK_EQ(gemm2_weights.ndim(), 3) << "gemm2_weights must be 3D."; - // / 2 to compensate for the fact that we pack 2 e2m1 into 1 byte. 
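// The "* 2" and "/ 2" adjustments throughout the FP4 path all follow from one packing rule:
// two e2m1 values share one byte, so packed uint8 tensors store half of the logical extent.
// Illustrative summary of logical-vs-stored shapes implied by the checks around here
// (E = local_num_experts, H = hidden_size, I = intermediate_size; the scale-vector size
// sf_vec_size is 16 for NvFP4/E2m1 weights and 32 for MxFP4/MxE2m1 weights):
//   hidden_states (E2m1 act.) : uint8 [num_tokens, H/2]   -> logical [num_tokens, H]
//   gemm1_weights             : uint8 [E, 2*I, H/2]       -> logical [E, 2*I, H]
//   gemm2_weights             : uint8 [E, H,   I/2]       -> logical [E, H,   I]
//   gemm1_weights_scale       : e4m3  [E, 2*I, H/sf_vec_size]
//   gemm2_weights_scale       : e4m3  [E, H,   I/sf_vec_size]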
- TVM_FFI_ICHECK_EQ(gemm2_weights.size(2), intermediate_size / 2) - << "the third dimension of weights must be equal to intermediate_size."; + // Use default values that match the original function behavior + bool use_shuffled_weight = true; // Original uses /*useShuffledMatrixA*/ true + int64_t weight_layout = 0; // Default to MajorK - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float8_e4m3fn) - << "gemm2_weights_scale must be fp8."; + // Calculate supported tile sizes + std::vector mSupportedTileN(Fp8PerTensorLauncher::mSupportedTileNums.begin(), + Fp8PerTensorLauncher::mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), local_num_experts) - << "gemm2_weights_scale has incorrect dim 0."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args.hidden_size) - << "gemm2_weights_scale has incorrect dim 1."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), intermediate_size / sf_vec_size) - << "gemm2_weights_scale has incorrect dim 2."; + // Create a map of launchers for each tile size + std::unordered_map> launchers_map; + + for (int32_t curr_tile_N : selected_tile_nums) { + // Create MoE arguments for this launcher + auto args = std::make_unique(); + args->num_tokens = num_tokens; + args->num_experts = num_experts; + args->hidden_size = hidden_size; + args->hidden_size_output = args->hidden_size; + args->top_k = top_k; + args->n_group = n_group.value_or(0); + args->topk_group = topk_group.value_or(0); + args->local_expert_offset = local_expert_offset; + args->local_num_experts = local_num_experts; + args->intermediate_size = intermediate_size; + args->routed_scaling_factor = routed_scaling_factor.value_or(1.0); + + // Create and initialize launcher for this tile size + auto launcher = std::make_unique( + routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, + output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar); + // Note: Original code passes tile_N where tile_tokens_dim is expected + // This seems incorrect but we match the original behavior + launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, + weight_layout, use_routing_scales_on_input); - if (output1_scales_scalar.has_value()) { - TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().dtype(), dl_float32) - << "output1_scales_scalar must be float."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().ndim(), 1) - << "output1_scales_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().size(0), local_num_experts) - << "output1_scales_scalar has incorrect dim 0."; + launchers_map[curr_tile_N] = std::move(launcher); } - if (output1_scales_gate_scalar.has_value()) { - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().dtype(), dl_float32) - << "output1_scales_gate_scalar must be float."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().ndim(), 1) - << "output1_scales_gate_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().size(0), local_num_experts) - << "output1_scales_gate_scalar has incorrect dim 0."; + // Extract tile_N and config from config_index + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); } - if (output2_scales_scalar.has_value()) { - 
TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().dtype(), dl_float32) - << "output2_scales_scalar must be float."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().ndim(), 1) - << "output2_scales_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().size(0), local_num_experts) - << "output2_scales_scalar has incorrect dim 0."; + // Get the launcher for the selected tile_N + auto& selected_launcher = launchers_map.at(tile_N); + + // Run the launcher - it will create its own runner internally + auto result = selected_launcher->run(config, enable_pdl, use_routing_scales_on_input)[0]; + // Return the result tensor + return result; +} + +Tensor trtllm_fp8_block_scale_moe( + TensorView routing_logits, Optional routing_bias, TensorView hidden_states, + TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, + TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output, + int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, + int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, + Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, bool enable_pdl, Array config_index) { + // Basic type validation + auto dtype = hidden_states.dtype(); + if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; + } else { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; } + TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) + << "FP8 block scale MoE: hidden_states must be fp16, bf16, or fp8."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32) + << "FP8 block scale MoE: hidden_states_scale must be float32."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) + << "FP8 block scale MoE: gemm1_weights must be fp8."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) + << "FP8 block scale MoE: gemm1_weights_scale must be float32."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) + << "FP8 block scale MoE: gemm2_weights must be fp8."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32) + << "FP8 block scale MoE: gemm2_weights_scale must be float32."; - // setup workspace - workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); - workspace.total_max_padded_tokens = - std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); - workspace.ProjUpTileN = tile_tokens_dim; - workspace.routing_expert_indexes = static_cast(expert_indices.data_ptr()); - workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); - workspace.expanded_idx_to_permuted_idx = static_cast( - expanded_idx_to_permuted_idx.data_ptr()); // Needed by permute/finalize kernels - workspace.permuted_idx_to_token_idx = - static_cast(permuted_idx_to_token_idx.data_ptr()); // Needed by permuteGemm1 kernel - workspace.expert_weights = expert_weights.data_ptr(); // Consumed by finalize kernel - - workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); - workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); - workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); - - workspace.hidden_states_scale_linear = nullptr; - - // gemm1 intermediate ws - workspace.gemm1_output = gemm1_output.data_ptr(); - 
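// Each refactored entry point resolves its tactic the same way: config_index / moe_tactic is
// a two-element array {tile_N, config}, where tile_N picks the launcher built for that tile
// size and config indexes that runner's valid configurations. A minimal caller-side sketch
// (variable names here are illustrative, not part of the patch):
//
//   Array<int64_t> tactic = {-1, -1};   // fallback: first selected tile_N, default config
//   auto out = trtllm_fp8_block_scale_moe(/* tensors and sizes */ ..., enable_pdl, tactic);
//
//   // an autotuner instead sweeps the flattened pairs reported by the launcher:
//   //   for (auto const& pair : Fp8BlockScaleLauncher::getValidConfigs(...))
//   //     time the kernel with `pair` and keep the fastest {tile_N, config}.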
workspace.gemm1_output_scale = gemm1_output_scale.has_value() - ? static_cast(gemm1_output_scale.value().data_ptr()) - : nullptr; - // gemm2 intermediate ws - workspace.gemm2_output = gemm2_output.data_ptr(); - workspace.gemm2_output_scale = nullptr; - args.output = output.data_ptr(); - args.output_scale = nullptr; - args.output1_scales_scalar = output1_scales_scalar.has_value() - ? static_cast(output1_scales_scalar.value().data_ptr()) - : nullptr; - args.output1_scales_gate_scalar = - output1_scales_gate_scalar.has_value() - ? static_cast(output1_scales_gate_scalar.value().data_ptr()) - : nullptr; - args.output2_scales_scalar = output2_scales_scalar.has_value() - ? static_cast(output2_scales_scalar.value().data_ptr()) - : nullptr; - args.do_finalize = do_finalize; - - auto const workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); - - Tensor workspace_fc1 = - alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); - Tensor workspace_fc2 = - alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states.device()); - workspace.bmm1_workspace = workspace_fc1.data_ptr(); - workspace.bmm2_workspace = workspace_fc2.data_ptr(); - cudaStream_t moe_stream = get_stream(hidden_states.device()); - moe_runner.run(args, workspace, hidden_states.device().device_id, moe_stream, moeConfigIndex, - enable_pdl); - - if (!do_finalize) { - return {gemm2_output, expanded_idx_to_permuted_idx}; + auto const num_tokens = hidden_states.size(0); + auto const hidden_size = hidden_states.size(1); + + std::vector mSupportedTileN(Fp8BlockScaleLauncher::mSupportedTileNums.begin(), + Fp8BlockScaleLauncher::mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + + // Create a map of launchers for each tile size + std::unordered_map> launchers_map; + + for (int32_t curr_tile_N : selected_tile_nums) { + // Create MoE arguments for this launcher + auto args = std::make_unique(); + args->num_tokens = num_tokens; + args->num_experts = num_experts; + args->hidden_size = hidden_size; + args->hidden_size_output = args->hidden_size; + args->top_k = top_k; + args->n_group = n_group.value_or(0); + args->topk_group = topk_group.value_or(0); + args->local_expert_offset = local_expert_offset; + args->local_num_experts = local_num_experts; + args->intermediate_size = intermediate_size; + args->routed_scaling_factor = routed_scaling_factor.value_or(1.0); + + // Create and initialize launcher for this tile size + auto launcher = std::make_unique( + routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, + gemm1_weights_scale, gemm2_weights, gemm2_weights_scale); + // Note: Original code passes tile_N where tile_tokens_dim is expected + // This seems incorrect but we match the original behavior + launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, + weight_layout); + + launchers_map[curr_tile_N] = std::move(launcher); } - return {}; + + // Extract tile_N and config from config_index + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + } + + // Get the launcher for the selected tile_N + auto& selected_launcher = launchers_map.at(tile_N); + + // Run the launcher with DeepSeek FP8 enabled - it will create its own runner internally + auto result = selected_launcher->run(config, enable_pdl, false /* use_routing_scales_on_input */, 
+ true /* use_deep_seek_fp8 */)[0]; + // Return the result tensor + return result; } Array trtllm_fp4_block_scale_moe( @@ -1188,26 +1508,47 @@ Array trtllm_fp4_block_scale_moe( int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t gated_act_type, TensorView output, Array config_index) { - using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; - + // Determine data types based on input format int const num_tokens = hidden_states.size(0); int hidden_size = hidden_states.size(1); if (hidden_states.dtype() == dl_uint8) hidden_size *= 2; + int hidden_states_scale_vec_size = -1; if (hidden_states_scale.has_value()) { hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel(); } int weight_scale_vec_size = (local_num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel(); + TVM_FFI_ICHECK(weight_scale_vec_size == 16 || weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size."; auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1; + if (routing_logits.has_value()) { + TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || + routing_logits.value().dtype() == dl_bfloat16) + << "routing_logits must be float or bfloat16."; + TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts) + << "routing_logits has incorrect shape."; + } + if (routing_bias.has_value()) { + TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || + routing_bias.value().dtype() == dl_float32) + << "routing_bias must be bfloat16 or float."; + + TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; + TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) + << "routing_bias has incorrect shape."; + } + + // Determine activation type TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8) << "weights must be fp4 packed in uint8."; TVM_FFI_ICHECK(hidden_states.dtype() == dl_uint8 || hidden_states.dtype() == dl_bfloat16 || hidden_states.dtype() == dl_float8_e4m3fn) << "hidden_states must be bf16, fp8 or uint8 (packed fp4)."; + auto mDtypeAct = btg::Dtype::Bfloat16; if (hidden_states.dtype() == dl_uint8) { TVM_FFI_ICHECK(hidden_states_scale.has_value() && @@ -1231,75 +1572,61 @@ Array trtllm_fp4_block_scale_moe( mDtypeAct = btg::Dtype::E4m3; } } - bool mUseDeepSeekFp8{false}; // FP4 doesn't use DeepSeek FP8 - std::vector mSupportedTileN = {8, 16, 32, 64}; - if (mDtypeAct != btg::Dtype::Bfloat16) { - mSupportedTileN.push_back(128); - } - if ((mDtypeAct == btg::Dtype::MxE4m3 && mDtypeWeights == btg::Dtype::MxE2m1) || - (mDtypeAct == btg::Dtype::E2m1 && mDtypeWeights == btg::Dtype::E2m1)) { - // MxFP4 x MxFP4 or NvFP4 x NvFP4 - mSupportedTileN.push_back(256); - } + // Determine supported tile sizes + std::vector mSupportedTileN = FP4BlockScaleLauncher::getSupportedTileNums(mDtypeAct); std::set selected_tile_nums = computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); - // Build runners for all supported tile sizes - std::unordered_map> mRunners; - for (int32_t tile_N : selected_tile_nums) { - mRunners.emplace(tile_N, - std::make_unique(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tile_N, - static_cast(gated_act_type), - /*useShuffledMatrixA*/ true)); + + // Create a map of launchers for each tile size + std::unordered_map> 
launchers_map; + + for (int32_t curr_tile_N : selected_tile_nums) { + // Create MoE arguments for this launcher + auto args = std::make_unique(); + args->num_tokens = num_tokens; + args->num_experts = num_experts; + // For E2m1, hidden_size is already multiplied by 2 above, so use it directly + args->hidden_size = hidden_size; + args->hidden_size_output = args->hidden_size; + args->top_k = top_k; + args->n_group = n_group.value_or(0); + args->topk_group = topk_group.value_or(0); + args->local_expert_offset = local_expert_offset; + args->local_num_experts = local_num_experts; + args->intermediate_size = intermediate_size; + args->routed_scaling_factor = routed_scaling_factor.value_or(1.0); + args->do_finalize = do_finalize; + args->output = output.data_ptr(); + args->output_scale = nullptr; + + // Create and initialize launcher for this tile size + auto launcher = std::make_unique( + routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, + gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, + gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, + output2_scales_scalar, topk_ids, expert_weights); + launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, + /*weight_layout=*/0, gated_act_type, mDtypeAct, mDtypeWeights); + + launchers_map[curr_tile_N] = std::move(launcher); } - // moeConfigIndex corresponds to pair (tile_N, config) + // Extract tile_N and config from config_index int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Autotuner has requested a default or 'fallback' config index + + // Handle default case if (tile_N == -1 || config == -1) { tile_N = *selected_tile_nums.begin(); - config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, - local_num_experts, num_tokens); + config = -1; // Let the runner choose default } - return trtllm_fp4_block_scale_moe_launcher( - routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale, - gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, - gemm2_weights, gemm2_weights_scale, gemm2_bias, output1_scales_scalar, - output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group, - intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_N, - routing_method_type, do_finalize, *mRunners[tile_N], mDtypeAct, mDtypeWeights, config, - enable_pdl, output); -} -int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const dtype_weights_, - bool const useDeepSeekFp8, int64_t const top_k, - int64_t const hidden_size, int64_t const intermediate_size, - int64_t const num_local_experts, - int64_t const gated_act_type, int64_t const num_tokens) { - auto dtype_act = static_cast(dtype_act_); - auto dtype_weights = static_cast(dtype_weights_); - std::vector supported_tile_nums = {8, 16, 32, 64}; - // Check if we should add tile size 128 - bool is_fp4_without_bf16_act = - (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) && - dtype_act != btg::Dtype::Bfloat16; - bool is_fp8_per_tensor = - dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; - - if (is_fp4_without_bf16_act || is_fp8_per_tensor) { - supported_tile_nums.push_back(128); - } - std::set selected_tile_nums = - computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); - - std::unique_ptr 
moe_runner = - std::make_unique( - dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(), - static_cast(gated_act_type), /*useShuffledMatrixA*/ true); + // Get the launcher for the selected tile_N + auto& selected_launcher = launchers_map.at(tile_N); - return moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens); + // Run the launcher - it will create its own runner internally + return selected_launcher->run(config, enable_pdl); } Array> trtllm_get_valid_moe_configs( @@ -1307,68 +1634,53 @@ Array> trtllm_get_valid_moe_configs( int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size, int64_t const num_local_experts, int64_t const gated_act_type, bool const use_shuffled_weight, int64_t const weight_layout, int64_t const num_tokens) { - // returns (tile_N, config) - Array> valid_configs; auto dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); - std::vector supported_tile_nums = {8, 16, 32, 64}; - // Check if we should add tile size 128 - bool is_fp4_without_bf16_act = - (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) && - dtype_act != btg::Dtype::Bfloat16; - bool is_fp8_per_tensor = - dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; - - if (useDeepSeekFp8) { - supported_tile_nums.push_back(128); - } else if (is_fp8_per_tensor) { - supported_tile_nums.push_back(128); - supported_tile_nums.push_back(192); - supported_tile_nums.push_back(256); - } else if (is_fp4_without_bf16_act) { - supported_tile_nums.push_back(128); - } - - if ((dtype_act == btg::Dtype::MxE4m3 && dtype_weights == btg::Dtype::MxE2m1) || - (dtype_act == btg::Dtype::E2m1 && dtype_weights == btg::Dtype::E2m1)) { - // MxFP4 x MxFP4 or NvFP4 x NvFP4 - supported_tile_nums.push_back(256); - } - std::set selected_tile_nums = - computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); - for (int32_t tile_N : selected_tile_nums) { - std::unique_ptr moe_runner; - - if (dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3) { - // FP8 block scale MOE runner - moe_runner = std::make_unique( - dtype_weights, useDeepSeekFp8, tile_N, use_shuffled_weight, - static_cast(weight_layout)); + if (dtype_act == btg::Dtype::Bfloat16 && dtype_weights == btg::Dtype::Bfloat16) { + // BF16 MoE + return Bf16MoeLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens, gated_act_type, + use_shuffled_weight, weight_layout); + + } else if (dtype_act == btg::Dtype::E4m3 && dtype_weights == btg::Dtype::E4m3) { + // FP8 + if (!useDeepSeekFp8) { + // FP8 per-tensor scale + return Fp8PerTensorLauncher::getValidConfigs( + top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, gated_act_type, + use_shuffled_weight, weight_layout, dtype_act, dtype_weights); } else { - // FP4 block scale MOE runner - moe_runner = std::make_unique( - dtype_act, dtype_weights, useDeepSeekFp8, tile_N, - static_cast(gated_act_type), - /*useShuffledMatrixA*/ true); - } - auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens); - for (auto cfg : cfgs) { - valid_configs.push_back({tile_N, cfg}); + // FP8 block scale + return Fp8BlockScaleLauncher::getValidConfigs( + top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, use_shuffled_weight, + weight_layout, dtype_weights); } + } else if (dtype_weights == btg::Dtype::E2m1 || 
dtype_weights == btg::Dtype::MxE2m1) { + // FP4 block scale + return FP4BlockScaleLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens, gated_act_type, + dtype_act, dtype_weights); } - return valid_configs; + + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "Unsupported data type combination for getValidConfigs: " + << "dtype_act=" << static_cast(dtype_act) + << ", dtype_weights=" << static_cast(dtype_weights) + << ", useDeepSeekFp8=" << useDeepSeekFp8; + + // Unreachable code - added to suppress compiler warning + return Array>(); } namespace trtllm_cubin_loader { #include } +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_bf16_moe, trtllm_bf16_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_per_tensor_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_block_scale_moe, trtllm_fp8_block_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp4_block_scale_moe, trtllm_fp4_block_scale_moe); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_get_default_moe_configs, trtllm_get_default_moe_configs); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_get_valid_moe_configs, trtllm_get_valid_moe_configs); } // namespace flashinfer diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index d3a63431a8..40d0fe90cb 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -435,8 +435,8 @@ void run(Data const& data, void* stream) { << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; // FIXME: routingIndicesBlockKernel breaks the vllm + gpt-oss DeepEP - // bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; - bool const useSingleBlock = false; + bool const useSingleBlock = + data.mNumTokens <= BlockKernelMaxNumTokens && data.mPtrTopKPacked == nullptr; bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 733b7aed24..520a3e1c6f 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -89,7 +89,7 @@ class ArtifactPath: TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "23daeee32b60bde7947ce1ee7a58d4ab701f134b/batched_gemm-0d28130-add42d1" + "c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988" ) TRTLLM_GEN_GEMM: str = ( "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" @@ -105,7 +105,7 @@ class MetaInfoHash: "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a" ) TRTLLM_GEN_BMM: str = ( - "6cfade1395f9648aba5dcf2c329114619e175c0f238882555178f98c8f5c1968" + "26c51b75921be90235d193675facdea5d8341c4c52c73bd0a7c8e787c0388beb" ) TRTLLM_GEN_GEMM: str = ( "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340" @@ -123,7 +123,7 @@ class CheckSumHash: "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f" ) TRTLLM_GEN_BMM: str = ( - "46ccf0492e3ed10135c2861a4f4ef9bb45846610f9a9d2ccaf2d5bf01d2006fd" + "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf" ) DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 2759105691..8121c99c0a 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -29,6 +29,7 @@ trtllm_fp4_block_scale_routed_moe, trtllm_fp8_block_scale_moe, 
trtllm_fp8_per_tensor_scale_moe, + trtllm_bf16_moe, ) __all__ = [ @@ -40,8 +41,11 @@ "gen_cutlass_fused_moe_sm120_module", "gen_cutlass_fused_moe_sm100_module", "gen_cutlass_fused_moe_sm90_module", + "gen_trtllm_gen_fused_moe_sm100_module", "reorder_rows_for_gated_act_gemm", + "trtllm_bf16_moe", "trtllm_fp4_block_scale_moe", + "trtllm_fp4_block_scale_routed_moe", "trtllm_fp8_block_scale_moe", "trtllm_fp8_per_tensor_scale_moe", ] diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 516e05c8fc..83f186673b 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -46,6 +46,7 @@ get_shuffle_matrix_sf_a_row_indices, register_custom_op, register_fake_op, + get_compute_capability, ) from .utils import ( get_last_power_of_2_num_tokens_buckets, @@ -177,6 +178,40 @@ class GatedActType(IntEnum): GeGlu = 1 +@functools.cache +def is_trtllm_moe_supported( + dtype_weights: DtypeTrtllmGen, + dtype_act: DtypeTrtllmGen, + quant_method: Optional[str] = None, +) -> bool: + arch = get_compute_capability(torch.cuda.current_device()) + if arch[0] < 10: + return False + if dtype_weights not in [ + DtypeTrtllmGen.Bfloat16, + DtypeTrtllmGen.E4m3, + DtypeTrtllmGen.E2m1, + DtypeTrtllmGen.MxE2m1, + ]: + return False + if ( + dtype_weights == DtypeTrtllmGen.Bfloat16 + and dtype_act != DtypeTrtllmGen.Bfloat16 + ): + return False + if dtype_weights == DtypeTrtllmGen.E4m3 and dtype_act != DtypeTrtllmGen.E4m3: + return False + if dtype_weights == DtypeTrtllmGen.E2m1 and dtype_act != DtypeTrtllmGen.E2m1: + return False + if dtype_weights == DtypeTrtllmGen.MxE2m1 and dtype_act not in [ + DtypeTrtllmGen.MxE2m1, + DtypeTrtllmGen.MxE4m3, + DtypeTrtllmGen.Bfloat16, + ]: + return False + return True + + def _maybe_get_cached_w3_w1_permute_indices( _cache_permute_indices, dst_w3_w1_weight: torch.Tensor, @@ -947,15 +982,6 @@ def __init__( self.gated_act_type = GatedActType(gated_act_type) self.use_shuffled_weight = use_shuffled_weight self.weight_layout = WeightLayout(weight_layout) - if ( - not self.use_shuffled_weight - or self.weight_layout != WeightLayout.MajorK - ): - assert ( - self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3 - ), ( - "use_shuffled_weight is False or weight_layout is not MajorK is only supported for FP8 block scale" - ) def get_valid_tactics( self, @@ -1037,7 +1063,28 @@ def forward( and hidden_states_scale.shape[0] == num_tokens ), "hidden_states_scale's first dimension must be batch size" # Choose the appropriate operation based on data types - if ( + if self.dtype_weights == DtypeTrtllmGen.Bfloat16: + # BF16 operations + moe_op.trtllm_bf16_moe( + routing_logits, + kwargs["routing_bias"], + hidden_states, + kwargs["gemm1_weights"], + kwargs["gemm2_weights"], + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routing_method_type"], + kwargs["use_shuffled_weight"], + kwargs["weight_layout"], + kwargs["enable_pdl"], + [-1, -1] if tactic == -1 else tactic, + ) + elif ( self.dtype_act == DtypeTrtllmGen.E4m3 and self.dtype_weights == DtypeTrtllmGen.E4m3 ): @@ -1163,6 +1210,134 @@ def refine_tuning_config(cls, tune_max_num_tokens: int): ), ) + @register_custom_op( + "flashinfer::trtllm_bf16_moe", + mutates_args=(""), + ) + def trtllm_bf16_moe_op( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: 
int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routing_method_type: int, + use_shuffled_weight: bool, + weight_layout: int, + enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, + ) -> torch.Tensor: + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) + + # Use AutoTuner to select the best tactic + tuner = AutoTuner.get() + MoERunner.refine_tuning_config(tune_max_num_tokens) + + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] + + # Create workspace buffers + output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device + ) + topk_ids = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + ) + expert_weights = torch.empty( + num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device + ) + + dtype_act = DtypeTrtllmGen.Bfloat16 + dtype_weights = DtypeTrtllmGen.Bfloat16 + + moe_runner = MoERunner( + top_k=top_k, + num_local_experts=local_num_experts, + dtype_act=dtype_act, + dtype_weights=dtype_weights, + use_deepseek_fp8=False, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + weight_layout=weight_layout, + use_shuffled_weight=use_shuffled_weight, + gated_act_type=GatedActType.SwiGlu, # Default for BF16 + ) + + inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] + + _, tactic = tuner.choose_one( + "flashinfer::trtllm_bf16_moe", + [moe_runner], + MoERunner.tuning_config_no_hidden_states_scales, + inputs, + routing_bias=routing_bias, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + n_group=n_group, + topk_group=topk_group, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routing_method_type=routing_method_type, + use_shuffled_weight=use_shuffled_weight, + weight_layout=weight_layout, + enable_pdl=enable_pdl, + ) + + # Call the C++ function with the selected tactic + result = moe_op.trtllm_bf16_moe( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routing_method_type, + use_shuffled_weight, + weight_layout, + enable_pdl, + [-1, -1] if tactic == -1 else tactic, + ) + return result + + @register_fake_op("flashinfer::trtllm_bf16_moe") + def _fake_trtllm_bf16_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routing_method_type: int, + use_shuffled_weight: bool, + weight_layout: int, + enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, + ): + seq_len = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] + + return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + @register_custom_op( "flashinfer::trtllm_fp8_per_tensor_scale_moe", mutates_args=(""), @@ -1248,7 +1423,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( enable_pdl=enable_pdl, ) # Call the C++ function - moe_op.trtllm_fp8_per_tensor_scale_moe( + result = moe_op.trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, hidden_states, @@ -1271,7 +1446,8 @@ def 
trtllm_fp8_per_tensor_scale_moe_op( enable_pdl, [-1, -1] if tactic == -1 else tactic, ) - return output + + return result @register_fake_op("flashinfer::trtllm_fp8_per_tensor_scale_moe") def _fake_trtllm_fp8_per_tensor_scale_moe( @@ -1395,7 +1571,7 @@ def trtllm_fp8_block_scale_moe_op( enable_pdl=enable_pdl, ) # Call the C++ function for block scale MoE - moe_op.trtllm_fp8_block_scale_moe( + result = moe_op.trtllm_fp8_block_scale_moe( routing_logits, routing_bias, hidden_states, @@ -1420,7 +1596,7 @@ def trtllm_fp8_block_scale_moe_op( [-1, -1] if tactic == -1 else tactic, ) - return output + return result @register_fake_op("flashinfer::trtllm_fp8_block_scale_moe") def _fake_trtllm_fp8_block_scale_moe( @@ -1671,12 +1847,93 @@ def _fake_trtllm_fp4_block_scale_moe( return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] return SimpleNamespace( + trtllm_bf16_moe=trtllm_bf16_moe_op, trtllm_fp8_per_tensor_scale_moe=trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe=trtllm_fp8_block_scale_moe_op, trtllm_fp4_block_scale_moe=trtllm_fp4_block_scale_moe_op, ) +def trtllm_bf16_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routing_method_type: int = 0, + use_shuffled_weight: bool = True, + weight_layout: int = WeightLayout.BlockMajorK, + enable_pdl: bool = True, + tune_max_num_tokens: int = 8192, +) -> torch.Tensor: + """BF16 MoE operation with autotuning support. + + This function implements a bfloat16 Mixture of Experts layer using the TensorRT-LLM backend + with automatic performance tuning for optimal tile size selection. + + Args: + routing_logits: [seq_len, num_experts] tensor of routing logits. + Supports float32 or bfloat16. + routing_bias: Optional [num_experts] tensor of routing bias. + Must be bfloat16 if provided. + hidden_states: [seq_len, hidden_size] tensor of input hidden states. + Must be bfloat16. + gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights. + Must be bfloat16. + gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights. + Must be bfloat16. + num_experts: Total number of experts. + top_k: Number of experts to route to per token. + n_group: Number of expert groups. + topk_group: Number of groups to consider for top-k routing. + intermediate_size: Size of intermediate layer. + local_expert_offset: Offset of local experts in global expert space. + local_num_experts: Number of experts handled by this device. + routing_method_type: Type of routing method to use (default: 0). + - 0: Default (Softmax -> TopK) + - 1: Renormalize (TopK -> Softmax) + - 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts) + - 3: Llama4 (Top1 -> Sigmoid) + - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) + use_shuffled_weight: Whether to use shuffled weight layout for optimization (default: True). + weight_layout: Weight layout format (default: WeightLayout.BlockMajorK). + - 0: MajorK - K-major layout [Mn, K] + - 1: MajorMn - M-major for A and N-major for B [K, Mn] + - 2: BlockMajorK - Blocked along K dimension [K/blockK, Mn, blockK] + enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90. 
+ tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192). + + Returns: + torch.Tensor: Output tensor of shape [seq_len, hidden_size]. + """ + return get_trtllm_moe_sm100_module().trtllm_bf16_moe( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routing_method_type, + use_shuffled_weight, + weight_layout, + enable_pdl, + tune_max_num_tokens, + ) + + def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 65f497ad90..1427b15245 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -40,6 +40,7 @@ trtllm_fp4_block_scale_moe, trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, + trtllm_bf16_moe, ) from flashinfer.fused_moe.core import ( get_w2_permute_indices_with_cache, @@ -218,6 +219,7 @@ class QuantMode(IntEnum): FP4_MXFP4_Bf16 = 3 FP8_BLOCK_SCALE = 4 FP8_PER_TENSOR = 5 + BF16 = 6 # ==================================================================================== @@ -794,7 +796,6 @@ def call_moe( weight_layout=static_data["weight_layout"], enable_pdl=enable_pdl, ) - return output.to(torch.float) def compute_reference(self, args): @@ -982,6 +983,155 @@ def get_tolerances(self): return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} +# ==================================================================================== +# BF16 Implementation +# ==================================================================================== + + +class BF16Moe(Moe): + """BF16 MoE implementation.""" + + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): + """No scaling for weights.""" + return { + "hidden_states_scale_global": None, + "gemm1_weights": gemm1_weights.to(torch.bfloat16), + "gemm1_scales": None, + "gemm1_scales_global": None, + "gemm2_weights": gemm2_weights.to(torch.bfloat16), + "gemm2_scales": None, + "gemm2_scales_global": None, + } + + def quantize_inputs(self, hidden_states, *unused_args): + """No scaling for hidden states.""" + return { + "hidden_states": hidden_states.to(torch.bfloat16), + "hidden_states_scale": None, + } + + def prepare_static_weights_for_kernel( + self, + args_dequant, + args, + gemm1_weights_orig, + gemm2_weights_orig, + hidden_size, + intermediate_size, + num_experts, + weight_processing, + ): + """Prepare quantized weights for kernel (done offline with weights).""" + + # Use shuffled weights with BlockMajorK layout for better performance + use_shuffled_weight = weight_processing["use_shuffled_weight"] + weight_layout = weight_processing["layout"] + + if use_shuffled_weight: + # FIXME: this depends on the kernel internals + epilogue_tile_m = 128 + + # Reorder rows of W1 for fused gated activation and shuffle for both W1 and W2 + # Using cached permute index calculation can speed up weights preprocessing + gemm1_weights_bf16_shuffled = [] + gemm2_weights_bf16_shuffled = [] + for i in range(num_experts): + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + args.gemm1_weights[i].view(torch.uint8), + epilogue_tile_m, + ) + tmp_weights1 = ( + args.gemm1_weights[i] + .view(torch.uint8)[permute_indices.to(args.gemm1_weights.device)] + .contiguous() + ) + + permute_indices = get_w2_permute_indices_with_cache( + 
self._cache_permute_indices, + args.gemm2_weights[i].view(torch.uint8), + epilogue_tile_m, + ) + tmp_weights2 = ( + args.gemm2_weights[i] + .view(torch.uint8)[permute_indices.to(args.gemm2_weights.device)] + .contiguous() + ) + + if weight_layout == WeightLayout.BlockMajorK: + block_k = 128 + tmp_weights1 = convert_to_block_layout( + tmp_weights1.view(torch.uint8), block_k + ) + tmp_weights2 = convert_to_block_layout( + tmp_weights2.view(torch.uint8), block_k + ) + + gemm1_weights_bf16_shuffled.append(tmp_weights1.view(torch.bfloat16)) + gemm2_weights_bf16_shuffled.append(tmp_weights2.view(torch.bfloat16)) + + # Stack weights for all experts + gemm1_weights_bf16_shuffled = ( + torch.stack(gemm1_weights_bf16_shuffled) + .view(torch.bfloat16) + .contiguous() + ) + gemm2_weights_bf16_shuffled = ( + torch.stack(gemm2_weights_bf16_shuffled) + .view(torch.bfloat16) + .contiguous() + ) + + return { + "gemm1_weights": gemm1_weights_bf16_shuffled, + "gemm2_weights": gemm2_weights_bf16_shuffled, + "use_shuffled_weight": use_shuffled_weight, + "weight_layout": weight_layout, + } + + def call_moe( + self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs + ): + """Call MoE with runtime input quantization + kernel execution (done at runtime).""" + expert_logits = kwargs["expert_logits"] + routing_bias = kwargs["routing_bias"] + num_experts = kwargs["num_experts"] + top_k = kwargs["top_k"] + n_groups = kwargs["n_groups"] + top_k_groups = kwargs["top_k_groups"] + intermediate_size = kwargs["intermediate_size"] + routing_method_type = kwargs["routing_method_type"] + + # Use autotuner for optimal kernel selection + with autotune(True): + output = trtllm_bf16_moe( + expert_logits, # float + routing_bias, + hidden_states_orig, + static_data["gemm1_weights"], + static_data["gemm2_weights"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + use_shuffled_weight=static_data["use_shuffled_weight"], + weight_layout=static_data["weight_layout"], + routing_method_type=routing_method_type, + ) + return output.to(torch.float) + + def compute_reference(self, args): + """BF16 reference implementation.""" + return run_moe_reference_bf16(args) + + def get_tolerances(self): + """Get BF16 accuracy tolerances.""" + return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} + + # ==================================================================================== # Quantizer Factory # ==================================================================================== @@ -1273,8 +1423,6 @@ def check_accuracy(a, b, atol, rtol, percent): count = torch.sum(left > right) mismatch_percent = count / a.numel() if mismatch_percent > 1 - percent: - print(a) - print(b) raise Exception( f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " f"(threshold: {1 - percent:.4f})" @@ -1581,6 +1729,9 @@ def run_moe_dequant(args, quant_mode: QuantMode): .to(torch.float) ) args.c_global_sf = 1.0 + elif quant_mode == QuantMode.BF16: + activation_output = activation_output.to(torch.bfloat16).to(torch.float) + args.c_global_sf = 1.0 else: # mxfp4Bf16 activation_output = activation_output.to(torch.bfloat16).to(torch.float) args.c_global_sf = 1.0 @@ -1786,6 +1937,37 @@ def run_moe_reference_per_tensor_scale_fp8(args): return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant +def run_moe_reference_bf16(args): + """BF16 reference implementation.""" + + # no scaling for hidden states and weights + hidden_states_dequant = args.hidden_states.to(torch.float) + 
gemm1_weights_dequant = {} + for i in range(args.num_experts): + gemm1_weights_dequant[i] = args.gemm1_weights[i].to(torch.float) + gemm2_weights_dequant = {} + for i in range(args.num_experts): + gemm2_weights_dequant[i] = args.gemm2_weights[i].to(torch.float) + + args_dequant = moe_args_dequant( + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + GatedActType.SwiGlu.value, # gated_act_type + ) + + return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant + + def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): """Unified actual computation that delegates to implementation-specific methods.""" # 1. Prepare static weights for the kernel (offline processing) @@ -2085,12 +2267,13 @@ def run_moe_test( # Test: Renormalize routing -@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) -@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ + pytest.param(BF16Moe(), id="BF16xBF16"), pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), @@ -2100,6 +2283,21 @@ def run_moe_test( @pytest.mark.parametrize( "routing_config", [ + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe], + "compatible_intermediate_size": [384, 768, 1024], + }, + id="Qwen3", + ), pytest.param( { "num_experts": 256, @@ -2110,8 +2308,8 @@ def run_moe_test( "routed_scaling": None, "has_routing_bias": False, "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], - "compatible_intermediate_size": [384, 768, 1024, 2048], + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe], + "compatible_intermediate_size": [384, 1024], }, id="Renorm", ), @@ -2125,7 +2323,7 @@ def run_moe_test( "routed_scaling": None, "has_routing_bias": False, "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe], "compatible_intermediate_size": [512], }, id="Qwen3_next", @@ -2135,6 +2333,14 @@ def run_moe_test( @pytest.mark.parametrize( "weight_processing", [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="NoShuffle_MajorK", + ), pytest.param( { "use_shuffled_weight": True, @@ -2143,6 +2349,14 @@ def run_moe_test( }, id="Shuffled_MajorK", ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + "compatible_moe_impls": [FP8BlockScaleMoe, BF16Moe], + }, + id="Shuffled_BlockMajorK", + ), ], ) @pytest.mark.parametrize( @@ -2176,7 +2390,7 @@ def test_renormalize_routing( # Test: DeepSeekV3 routing -@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 3072]) 
@pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( @@ -2202,7 +2416,7 @@ def test_renormalize_routing( "has_routing_bias": True, "routing_method_type": RoutingMethodType.DeepSeekV3, "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - "compatible_intermediate_size": [512, 1024, 2048], + "compatible_intermediate_size": [1024, 2048], }, id="kimi_k2", ), From d56748ffbce99d8cd8db20688267ae254b30cb16 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Sun, 9 Nov 2025 08:34:34 +0800 Subject: [PATCH 043/130] Fix: several bugs/issues with trtllm-gen attention kernels. (#2062) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This MR fixes: 1. unspecified cuda launch errors with 2CTA MLA kernels 2. masking bug of SWA decode kernels. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Added Sparse MLA support and propagated its flag through kernel selection and dispatch. * **Bug Fixes / Improvements** * Enforced power-of-two page sizing for paged KV caches and tightened head-dimension limits for broader hardware compatibility. * Updated kernel trait encoding and hash construction to include the sparse MLA flag and revised bit-field layout. * **Chores** * Updated runtime kernel artifact identifiers and checksums. * Extended kernel parameter fields, zero-initialized params on setup, and populated tokens-per-page log2 for paged KV. --------- Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Co-authored-by: yzh119 Co-authored-by: Zihao Ye --- flashinfer/artifacts.py | 4 +- .../flashinfer/trtllm/fmha/fmhaKernels.cuh | 43 ++++++++++--------- include/flashinfer/trtllm/fmha/kernelParams.h | 20 +++++++++ 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 520a3e1c6f..f88328f4a1 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -87,7 +87,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. 
""" - TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "b793e1b2cf7c419f070372ba55bbe53ca6fb9016/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988" ) @@ -120,7 +120,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f" + "20c017db0761a30130f05080ed2078f6c8044c0c2b3be7c4353ec740034b4432" ) TRTLLM_GEN_BMM: str = ( "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf" diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index d3e2b89c85..5bd91f4064 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -96,14 +96,15 @@ class TllmGenFmhaKernel { inline uint64_t hashID(int qkvLayout, int maskType, int kernelType, int scheduler, int multiCtasKvMode, int headDimPerCtaV, int headDimQk, int headDimV, int tileSizeKv, int numTokensPerPage, int maxNumHeadsQPerKvInCta, - bool reuseSmemKForV, bool uses2CtaMma) const { + bool reuseSmemKForV, bool uses2CtaMma, bool sparseMla) const { FLASHINFER_CHECK((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) && - (headDimPerCtaV <= 2048) && (headDimQk <= 2048) && (headDimV <= 2048) && - (numTokensPerPage <= 128), - "Expect (32 <= headDim <= 2048) && (numTokensPerPage <= 128), " - "got headDimPerCtaV=%d, headDimQk=%d, " - "headDimV=%d, numTokensPerPage=%d", - headDimPerCtaV, headDimQk, headDimV, numTokensPerPage); + (headDimPerCtaV <= 1024) && (headDimQk <= 1024) && (headDimV <= 1024), + "Expect (32 <= headDim <= 1024), got headDimPerCtaV=%d, headDimQk=%d, " + "headDimV=%d", + headDimPerCtaV, headDimQk, headDimV); + // The numTokensPerPage must be power of 2. + FLASHINFER_CHECK((numTokensPerPage & (numTokensPerPage - 1)) == 0, + "The numTokensPerPage must be power of 2."); FLASHINFER_CHECK(maxNumHeadsQPerKvInCta <= 128, "The maxNumHeadsQPerKvInCta <= 128 is required."); FLASHINFER_CHECK(tileSizeKv == 64 || tileSizeKv == 128, "The tileSizeKv must be 64 or 128."); @@ -113,25 +114,26 @@ class TllmGenFmhaKernel { // Bit 8 - 11: kernelType. // Bit 12 - 15: tileScheduler. // Bit 16 - 17: multiCtasKvMode. - // Bit 18 - 24: (headDimPerCtaV >> 5). - // Bit 25 - 31: (headDimQk >> 5). - // Bit 32 - 38: (headDimV >> 5). - // Bit 39 - 40: (tileSizeKv >> 6). - // Bit 41 - 48: numTokensPerPage. + // Bit 18 - 25: (headDimPerCtaV >> 3). + // Bit 26 - 33: (headDimQk >> 3). + // Bit 34 - 41: (headDimV >> 3). + // Bit 42 - 43: (tileSizeKv >> 6). + // Bit 44 - 48: (log2(numTokensPerPage)). // Bit 49 - 56: maxNumHeadsQPerKvInCta. // Bit 57 - 57: reuseSmemKForV. // Bit 58 - 58: uses2CtaMma. + // Bit 59 - 59: sparseMla. 
return (static_cast(qkvLayout) << 0) | (static_cast(maskType) << 4) | (static_cast(kernelType) << 8) | (static_cast(scheduler) << 12) | (static_cast(multiCtasKvMode) << 16) | - (static_cast(headDimPerCtaV >> 5) << 18) | - (static_cast(headDimQk >> 5) << 25) | - (static_cast(headDimV >> 5) << 32) | - (static_cast(tileSizeKv >> 6) << 39) | - (static_cast(numTokensPerPage) << 41) | + (static_cast(headDimPerCtaV >> 3) << 18) | + (static_cast(headDimQk >> 3) << 26) | + (static_cast(headDimV >> 3) << 34) | + (static_cast(tileSizeKv >> 6) << 42) | + (static_cast(log2(numTokensPerPage)) << 44) | (static_cast(maxNumHeadsQPerKvInCta) << 49) | (static_cast(reuseSmemKForV) << 57) | - (static_cast(uses2CtaMma) << 58); + (static_cast(uses2CtaMma) << 58) | (static_cast(sparseMla) << 59); } uint64_t hashID(KernelMeta const& kernelMeta) const { @@ -140,7 +142,7 @@ class TllmGenFmhaKernel { kernelMeta.mHeadDimPerCtaV, kernelMeta.mHeadDimQk, kernelMeta.mHeadDimV, kernelMeta.mTileSizeKv, kernelMeta.mNumTokensPerPage, kernelMeta.mMaxNumHeadsQPerKvInCta, kernelMeta.mReuseSmemKForV, - kernelMeta.m2CtaMma); + kernelMeta.m2CtaMma, kernelMeta.mSparseMla); } std::pair checkIfKernelExist(RunnerParams const& params) const { @@ -552,7 +554,8 @@ class TllmGenFmhaKernel { static_cast(selectKernelParams.mMultiCtasKvMode), selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV, selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta, - selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma), + selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma, + /* sparseMla */ false), info); } diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 57adc57914..533b98c9e0 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -104,6 +104,8 @@ struct KernelParams { // The sequence lengths for K/V. Required by pagedKv kernels to avoid unnecessary computation // based on (ptrCumSeqLensKv[batchIdx + 1] - ptrCumSeqLensKv[batchIdx]). int32_t const* ptrSeqLensKv; + // The reserved memory buffer. + int32_t* ptrReservedMem; // The softmax stats buffer. float2* ptrSoftmaxStats; @@ -139,6 +141,8 @@ struct KernelParams { int64_t mNumHiddenEltsO; // The total number of pages in the paged-kv memory pool. int32_t mNumPagesInMemPool; + // The number of tokens per page (used if dynamic numTokensPerPage is enabled). + int32_t mNumTokensPerPageLog2; // The output scale for FP8 quantization. float mOutputScale; // The scaling factor for softmax (multiplied by log2 to use faster exp2). @@ -147,11 +151,15 @@ struct KernelParams { float mScaleSfKv; // The SF scale for O. float mScaleSfO; + // The reserved parameter. + float mReservedParam; // The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase // kernel when inflight batching is enabled in TRT-LLM. int32_t mStartTokenIdxSfO; // The sum of sequence lengths for Q and K/V. int32_t mSumOfSeqLensQ, mSumOfSeqLensKv; + // The sparseMla topK value. + int32_t mSparseMlaTopK; // The flag to use block sparse attention. bool mUseBlockSparseAttention; @@ -537,6 +545,8 @@ struct KernelParams { int32_t maxNumCtasQ, int32_t maxNumCtasKv) { // Create the return struct. KernelParams params; + // Memset the kernel parameters to 0. + memset(¶ms, 0, sizeof(KernelParams)); // Get the device pointers for TMA descriptors. 
auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bytes(kernelMeta.mDataTypeKv)); @@ -681,6 +691,16 @@ struct KernelParams { // Default 0 means that chunked attention is disabled. params.mChunkedAttentionSizeLog2 = 0; } + + // Compute the log of numTokensPerPage + int32_t numTokensPerPageLog2{-1}; + if (isPagedKv(options.mQkvLayout)) { + FLASHINFER_CHECK((options.mNumTokensPerPage & (options.mNumTokensPerPage - 1)) == 0, + "NumTokensPerPage must be power of 2"); + numTokensPerPageLog2 = (int)log2f((float)options.mNumTokensPerPage); + } + params.mNumTokensPerPageLog2 = numTokensPerPageLog2; + params.mMaxSeqLenQ = options.mMaxSeqLenQ; params.mMaxSeqLenKv = options.mMaxSeqLenKv; params.mMaxNumCtasQ = maxNumCtasQ; From 8d7d0bc3baedd35c797ffd919e92760de864ab3f Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Nov 2025 15:14:06 -0800 Subject: [PATCH 044/130] refactor: remove MetaInfoHash class (#2064) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This class is not required after @jimmyzho 's refactor work in https://github.com/flashinfer-ai/flashinfer/pull/1967/files, and the only remaining pieces requiring its value is deepgemm (because of different artifact structure, deepgemm only have a kernel_map.json instead of header file). In this PR we remove the class `MetaInfoHash` to stop people further updating its content, and move the special case kernel map hash to deepgemm.py . ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). 
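A minimal usage sketch of the relocated hash, based on the `KernelMap` shape in the diff below (the lookup key is hypothetical):

```python
# Sketch only: after this refactor the deep-gemm kernel map owns its own hash,
# so no caller needs MetaInfoHash anymore.
from flashinfer.deep_gemm import KERNEL_MAP, KernelMap

print(KernelMap.KERNEL_MAP_HASH)         # sha256 of kernel_map.json, updated alongside the cubins
KERNEL_MAP.init_indices()                # downloads/validates kernel_map.json against that hash
cubin_name = KERNEL_MAP["<kernel-key>"]  # "<kernel-key>" stands in for a real kernel-map entry
```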
## Reviewer Notes cc @jimmyzho @cyx-6 ## Summary by CodeRabbit * **Refactor** * Simplified kernel metadata handling so kernel map objects initialize without external hash input * Standardized internal kernel validation using a fixed internal hash constant * Removed an obsolete public checksum data structure from the API surface * Reduced and cleaned up the public API to be leaner and clearer --- flashinfer/artifacts.py | 15 +-------------- flashinfer/deep_gemm.py | 12 +++++++----- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index f88328f4a1..60853ecd20 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -95,23 +95,10 @@ class ArtifactPath: "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" ) CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/" + # For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" -@dataclass(frozen=True) -class MetaInfoHash: - DEEPGEMM: str = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb" - TRTLLM_GEN_FMHA: str = ( - "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a" - ) - TRTLLM_GEN_BMM: str = ( - "26c51b75921be90235d193675facdea5d8341c4c52c73bd0a7c8e787c0388beb" - ) - TRTLLM_GEN_GEMM: str = ( - "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340" - ) - - class CheckSumHash: """ This class is used to store the checksums of the cubin files in artifactory. diff --git a/flashinfer/deep_gemm.py b/flashinfer/deep_gemm.py index 4da91750fd..c7e42494d4 100644 --- a/flashinfer/deep_gemm.py +++ b/flashinfer/deep_gemm.py @@ -41,7 +41,7 @@ import torch -from .artifacts import ArtifactPath, MetaInfoHash +from .artifacts import ArtifactPath from .cuda_utils import checkCudaErrors from .jit.cubin_loader import get_cubin from .jit.env import FLASHINFER_CUBIN_DIR @@ -1487,13 +1487,15 @@ def m_grouped_fp8_gemm_nt_masked( class KernelMap: - def __init__(self, sha256: str): - self.sha256 = sha256 + # Hash for kernel_map.json, updated when deepgemm cubins are republished + KERNEL_MAP_HASH = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb" + + def __init__(self): self.indice = None def init_indices(self): indice_path = ArtifactPath.DEEPGEMM + "/" + "kernel_map.json" - assert get_cubin(indice_path, self.sha256), ( + assert get_cubin(indice_path, self.KERNEL_MAP_HASH), ( "cubin kernel map file not found, nor downloaded with matched sha256" ) path = FLASHINFER_CUBIN_DIR / indice_path @@ -1513,4 +1515,4 @@ def __getitem__(self, key): return self.indice[key] -KERNEL_MAP = KernelMap(MetaInfoHash.DEEPGEMM) +KERNEL_MAP = KernelMap() From f5a06a4ac69cc56a5e6c4535cf9d5d6970a7e5af Mon Sep 17 00:00:00 2001 From: FlashInfer Bot Date: Sun, 9 Nov 2025 19:57:41 -0800 Subject: [PATCH 045/130] chore: Update CODEOWNERS (#2067) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR updates the CODEOWNERS file based on git commit history analysis from the last 180 days. ## Changes - Updated `.github/CODEOWNERS` with current code ownership based on: - Commit frequency - File coverage - Commit recency ## How to Review 1. Review the changes to `.github/CODEOWNERS` 2. Verify that the assigned owners are appropriate for each module 3. 
Make manual adjustments if needed before merging ## Notes - This is an automated PR generated weekly - Minimum commits threshold: 1 - Analysis period: 180 days - Directory depth: 3 levels - Top N owners per module: 5 --- ๐Ÿค– This PR was automatically generated by the [update-codeowners workflow](.github/workflows/update-codeowners.yml) ## Summary by CodeRabbit * **Chores** * Updated code ownership and review responsibilities across project directories. Co-authored-by: flashinfer-bot Co-authored-by: Claude --- .github/CODEOWNERS | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 2e26c661f6..24f6838702 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,21 +3,21 @@ # Analysis period: 180 days # Minimum commits threshold: 1 -benchmarks/ @bkryu @cyx-6 @jiahanc @nv-yunzheq @kahyunnam +benchmarks/ @bkryu @cyx-6 @yzh119 @jiahanc @nv-yunzheq benchmarks/routines/ @bkryu @nv-yunzheq @cyx-6 @nvmbreughe @Anerudhan ci/ @cyx-6 @yzh119 @nvmbreughe ci/scripts/ @cyx-6 ci/scripts/jenkins/ @cyx-6 -csrc/ @wenscarl @yzh119 @cyx-6 @djmmoss @yongwww -csrc/fused_moe/ @yzh119 @yongwww @djmmoss @cyx-6 @wenscarl -csrc/fused_moe/cutlass_backend/ @yzh119 @yongwww @djmmoss @cyx-6 @wenscarl -csrc/nv_internal/ @wenscarl @djmmoss @cyx-6 @yzh119 @yongwww -csrc/nv_internal/cpp/ @wenscarl @yongwww @djmmoss @joker-eph @ttyio -csrc/nv_internal/include/ @wenscarl -csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @cyx-6 @yzh119 @yongwww +csrc/ @wenscarl @yzh119 @cyx-6 @djmmoss @nv-yunzheq +csrc/fused_moe/ @nv-yunzheq @yzh119 @yongwww @djmmoss @cyx-6 +csrc/fused_moe/cutlass_backend/ @nv-yunzheq @yzh119 @yongwww @djmmoss @cyx-6 +csrc/nv_internal/ @wenscarl @djmmoss @nv-yunzheq @yongwww @cyx-6 +csrc/nv_internal/cpp/ @wenscarl @bkryu @yongwww @djmmoss @joker-eph +csrc/nv_internal/include/ @wenscarl @nv-yunzheq +csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @nv-yunzheq @yongwww @cyx-6 csrc/xqa/ @cyx-6 @yzh119 docs/ @yzh119 @cyx-6 @wenscarl @nv-yunzheq @aleozlx -flashinfer/ @yzh119 @cyx-6 @wenscarl @nvmbreughe @yongwww +flashinfer/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @jiahanc flashinfer-cubin/ @yzh119 @cyx-6 flashinfer-cubin/flashinfer_cubin/ @yzh119 flashinfer-jit-cache/ @yzh119 @cyx-6 @@ -25,19 +25,21 @@ flashinfer-jit-cache/flashinfer_jit_cache/ @yzh119 flashinfer/comm/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @djmmoss flashinfer/cudnn/ @Anerudhan @yzh119 @cyx-6 @Anerudhan flashinfer/cute_dsl/ @yzh119 @kaixih @Amir-19 @aleozlx -flashinfer/fused_moe/ @djmmoss @yzh119 @cyx-6 @wenscarl @IwakuraRein -flashinfer/jit/ @yzh119 @cyx-6 @djmmoss @jiahanc @aleozlx +flashinfer/dsv3_ops/ @nvmbreughe +flashinfer/fused_moe/ @djmmoss @yzh119 @cyx-6 @jiahanc @wenscarl +flashinfer/gemm/ @nvmbreughe +flashinfer/jit/ @yzh119 @cyx-6 @jiahanc @nvmbreughe @nv-yunzheq flashinfer/jit/attention/ @yzh119 @cyx-6 @Anerudhan @joker-eph -flashinfer/jit/gemm/ @yzh119 +flashinfer/jit/gemm/ @yzh119 @nv-yunzheq @jiahanc flashinfer/logits_processor/ @cyx-6 @yzh119 flashinfer/profiler/ @cyx-6 -flashinfer/triton/ @cyx-6 @nvmbreughe @yzh119 +flashinfer/triton/ @nvmbreughe @cyx-6 flashinfer/tuning_configs/ @kaixih -include/ @yzh119 @wenscarl @kahyunnam @joker-eph @cyx-6 -include/flashinfer/ @yzh119 @wenscarl @kahyunnam @joker-eph @cyx-6 +include/ @yzh119 @jiahanc @nvmbreughe @bkryu @wenscarl +include/flashinfer/ @yzh119 @jiahanc @nvmbreughe @bkryu @wenscarl include/flashinfer/attention/ @yzh119 @kahyunnam @joker-eph include/flashinfer/comm/ @yongwww 
@nvmbreughe @djmmoss @yzh119 @cyx-6 -include/flashinfer/gemm/ @ttyio @yongwww @aleozlx -include/flashinfer/trtllm/ @joker-eph @aleozlx @yzh119 @cyx-6 @wenscarl +include/flashinfer/gemm/ @ttyio @yongwww @nvmbreughe @aleozlx +include/flashinfer/trtllm/ @jiahanc @joker-eph @aleozlx @yzh119 @cyx-6 profiler/ @cyx-6 scripts/ @yzh119 @nvmbreughe @dierksen @yongwww @bkryu From d42fb90ee59d269b77b218a93467f8af58756eba Mon Sep 17 00:00:00 2001 From: qsang-nv <200703406+qsang-nv@users.noreply.github.com> Date: Mon, 10 Nov 2025 23:07:31 +0800 Subject: [PATCH 046/130] feat: add xqa mla backend (#2053) --- csrc/xqa/mla_sm120.cu | 6 +- flashinfer/decode.py | 236 ++++++++++++++++--- tests/attention/test_trtllm_gen_mla.py | 94 +++++--- tests/attention/test_xqa_mla_batch_decode.py | 187 +++++++++++++++ 4 files changed, 454 insertions(+), 69 deletions(-) create mode 100644 tests/attention/test_xqa_mla_batch_decode.py diff --git a/csrc/xqa/mla_sm120.cu b/csrc/xqa/mla_sm120.cu index ffcf8ab3c5..30863edced 100644 --- a/csrc/xqa/mla_sm120.cu +++ b/csrc/xqa/mla_sm120.cu @@ -1790,17 +1790,17 @@ void launchMLAFlashInfer( uint32_t const nbVHeads = nbKHeads; uint32_t const nbQHeads = nbKHeads * headGrpSize; uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; - uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + /*uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { float const factor = 4.f; return mha::min( mha::max( 1U, (uint32_t)round(multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)), divUp(maxSeqLen, tokensPerTile * 2)); - }(); + }();*/ // MLA disables multi-block mode for now // printf("nbSubSeqPerSeq = %u\n", nbSubSeqPerSeq); // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == // nbInputSeqSplit - dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimGrid{4 * inputSeqLen, 1, nbKHeads * batchSize}; dim3 const dimCta{warp_size * 4 * 3, 1, 1}; auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl); uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); diff --git a/flashinfer/decode.py b/flashinfer/decode.py index a85a1b846c..5db7d95a51 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -21,7 +21,7 @@ import torch -from .xqa import xqa +from .xqa import xqa, xqa_mla from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .jit import ( gen_batch_decode_mla_module, @@ -2437,11 +2437,9 @@ def xqa_batch_decode_with_kv_cache( kv_scale_value = bmm2_scale q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5) - query_new = query.unsqueeze(1).contiguous() - seq_lens_new = seq_lens.unsqueeze(1).contiguous() - sinks_new = ( - sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None - ) + query_new = query.unsqueeze(1) + seq_lens_new = seq_lens.unsqueeze(1) + sinks_new = sinks.reshape(num_kv_heads, -1) if sinks is not None else None # Ensure 4D output for xqa if out is None: @@ -2530,6 +2528,7 @@ def trtllm_batch_decode_with_kv_cache_mla( bmm2_scale_tensor: Optional[torch.Tensor] = None, sinks: Optional[List[torch.Tensor]] = None, enable_pdl: bool = None, + backend: str = "auto", ) -> torch.Tensor: """ Parameters: @@ -2548,6 +2547,173 @@ def trtllm_batch_decode_with_kv_cache_mla( bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in. bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input. 
sinks: additional value per head in the denominator of the softmax. + backend : str = "auto" + The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. + When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. + For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. + For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. + + Note: + In MLA, the actual BMM1 and BMM2 scales applied would be fused as: + bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) + bmm2_scale = v_scale * o_scale + or, + bmm1_scale_log2_tensor = [q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) * M_LOG2E] + bmm2_scale_tensor = [v_scale * o_scale] + + The two scale factors should be static constant for cuda graph capture. + Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. + + For static constant scale factors, the scale factors should be provided as float. + - (bmm1_scale, bmm2_scale) + For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor. + - (bmm1_scale_log2_tensor, bmm2_scale_tensor) + - Currently, only fp8 tensor core operation supports this mode. + When both are provided, the dynamic scale factor tensors will be used. + """ + if backend == "auto": + backend = ( + "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" + ) + if backend == "xqa": + if ( + get_compute_capability(query.device)[0] != 12 + or query.dtype != torch.float8_e4m3fn + or kv_cache.dtype != torch.float8_e4m3fn + ): + raise ValueError( + f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}" + ) + if sinks is not None: + raise ValueError("XQA MLA does not support sinks") + if query.size(1) != 1: + raise ValueError( + f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}" + ) + return xqa_batch_decode_with_kv_cache_mla( + query, + kv_cache, + workspace_buffer, + qk_nope_head_dim, + kv_lora_rank, + qk_rope_head_dim, + block_tables, + seq_lens, + max_seq_len, + out, + bmm1_scale, + bmm2_scale, + sinks, + enable_pdl, + ) + elif backend == "trtllm-gen": + enable_pdl = ( + device_support_pdl(query.device) if enable_pdl is None else enable_pdl + ) + run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode + sm_count = get_device_sm_count(query.device) + + block_size = kv_cache.size(-2) + if ( + block_size != 32 and block_size != 64 + ): # todo(Yingyi): add support for more block sizes? 
+ raise ValueError(f"Supported block_size are 32 and 64, got {block_size}") + + _check_trtllm_gen_mla_shape( + query, + kv_cache, + qk_nope_head_dim, + kv_lora_rank, + qk_rope_head_dim, + block_tables, + block_size, + ) + + if out is None: + out_shape = query.shape[:-1] + (kv_lora_rank,) + out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device) + else: + batch_size, _, num_q_heads, _ = query.shape + check_shape_dtype_device( + out, + [batch_size, num_q_heads, kv_lora_rank], + torch.bfloat16, + query.device, + "out", + ) + + if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: + # dynamic scale factors + if ( + query.dtype != torch.float8_e4m3fn + or kv_cache.dtype != torch.float8_e4m3fn + ): + raise ValueError( + "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation" + ) + + run_func( + out, + None, # fp4 output not supported in wrapper api yet. + query, + kv_cache, + kv_cache, + workspace_buffer, + block_tables, + seq_lens, + max_seq_len, + bmm1_scale, + bmm2_scale, + -1, # o_sf_scale + -1, # o_sf_vec_size + 0, # o_sf_start_index + -1, # window_left + sm_count, + enable_pdl, + workspace_buffer.numel() * workspace_buffer.element_size(), + sinks, + ) + + return out + else: + raise ValueError(f"Backend {backend} not supported") + + +def xqa_batch_decode_with_kv_cache_mla( + query: torch.Tensor, + kv_cache: torch.Tensor, + workspace_buffer: torch.Tensor, + qk_nope_head_dim: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + out: Optional[torch.Tensor] = None, + bmm1_scale: Optional[float] = 1.0, + bmm2_scale: Optional[float] = 1.0, + bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, + bmm2_scale_tensor: Optional[torch.Tensor] = None, + sinks: Optional[List[torch.Tensor]] = None, + enable_pdl: bool = None, +) -> torch.Tensor: + """ + Parameters: + query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. + kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache + workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use. + qk_nope_head_dim: qk_nope_head_dim, must be 128 + kv_lora_rank: kv_lora_rank, must be 512 + qk_rope_head_dim: qk_rope_head_dim, must be 64 + block_tables: page_table of kv cache, [batch_size, num_pages] + seq_lens: query_len + max_seq_len: max sequence length for kv_cache + out: output tensor, if not provided, will be allocated internally + bmm1_scale: fused scale for mla bmm1 input. + bmm2_scale: fused scale for mla bmm2 input. + bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in. + bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input. + sinks: additional value per head in the denominator of the softmax. Note: In MLA, the actual BMM1 and BMM2 scales applied would be fused as: @@ -2568,14 +2734,20 @@ def trtllm_batch_decode_with_kv_cache_mla( When both are provided, the dynamic scale factor tensors will be used. 
""" enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl - run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode sm_count = get_device_sm_count(query.device) block_size = kv_cache.size(-2) - if ( - block_size != 32 and block_size != 64 - ): # todo(Yingyi): add support for more block sizes? - raise ValueError(f"Supported block_size are 32 and 64, got {block_size}") + q_len_per_request = query.size(1) + if q_len_per_request != 1: + raise ValueError( + f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}" + ) + if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn: + raise ValueError( + f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}" + ) + if sinks is not None: + raise ValueError("XQA MLA does not support sinks") _check_trtllm_gen_mla_shape( query, @@ -2600,33 +2772,27 @@ def trtllm_batch_decode_with_kv_cache_mla( "out", ) - if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: - # dynamic scale factors - if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn: - raise ValueError( - "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation" - ) + workspace_u8 = workspace_buffer.view(torch.uint8) + semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore + scratch = workspace_u8[8 * 1024 * 1024 :] + # This can not be replaced by kv_cache.transpose(1, 2) because the stride is not the same + kv_cache_new = kv_cache.squeeze(1).unsqueeze(2) + seq_lens_new = seq_lens.unsqueeze(1) - run_func( - out, - None, # fp4 output not supported in wrapper api yet. + xqa_mla( query, - kv_cache, - kv_cache, - workspace_buffer, + kv_cache_new, + kv_cache_new, block_tables, - seq_lens, - max_seq_len, - bmm1_scale, - bmm2_scale, - -1, # o_sf_scale - -1, # o_sf_vec_size - 0, # o_sf_start_index - -1, # window_left - sm_count, - enable_pdl, - workspace_buffer.numel() * workspace_buffer.element_size(), - sinks, + seq_lens_new, + out, + scratch, + semaphore, + block_size, + q_scale=bmm1_scale, + kv_scale=bmm2_scale, + sm_count=sm_count, + enable_pdl=enable_pdl, ) return out diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index db6d827d67..999eda2a8a 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -23,6 +23,7 @@ ) # todo(Yingyi): verify larger q_len_per_request @pytest.mark.parametrize("dynamic_scale", [False]) @pytest.mark.parametrize("enable_pdl", [True, False, None]) +@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"]) def test_trtllm_batch_decode_mla( batch_size: int, scale: float, @@ -31,10 +32,19 @@ def test_trtllm_batch_decode_mla( q_len_per_request: int, dynamic_scale: bool, enable_pdl: bool, + backend: str, ): compute_capability = get_compute_capability(torch.device(device="cuda")) - if compute_capability[0] != 10: - pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") + if backend == "xqa": + if compute_capability[0] != 12: + pytest.skip("XQA MLA only supports SM120 GPUs") + if q_len_per_request != 1 or dtype != torch.float8_e4m3fn: + pytest.skip( + "XQA MLA only supports q_len_per_request == 1 and dtype == torch.float8_e4m3fn" + ) + if backend == "trtllm-gen": + if compute_capability[0] != 10: + pytest.skip("TRTLLM-GEN MLA only supports SM100 and SM103 GPUs") if dynamic_scale and dtype != torch.float8_e4m3fn: pytest.skip("Dynamic scale is 
not supported for non-fp8 dtype") @@ -72,7 +82,7 @@ def test_trtllm_batch_decode_mla( max_num_blocks_per_seq = blocks_per_seq.max().item() # Generate random but unique block IDs for all sequences - total_blocks_needed = sum(blocks_per_seq) + total_blocks_needed = int(blocks_per_seq.sum().item()) all_block_ids = torch.randperm( total_blocks_needed, device=device ) # Random permutation @@ -86,7 +96,7 @@ def test_trtllm_batch_decode_mla( # Populate block tables and track block assignments block_id = 0 for i in range(batch_size): - num_blocks_needed = blocks_per_seq[i] + num_blocks_needed = int(blocks_per_seq[i].item()) block_tables[i, :num_blocks_needed] = all_block_ids[ block_id : block_id + num_blocks_needed ] @@ -144,6 +154,7 @@ def test_trtllm_batch_decode_mla( bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, bmm2_scale_tensor=bmm2_scale_tensor, enable_pdl=enable_pdl, + backend=backend, ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future @@ -198,31 +209,52 @@ def test_trtllm_batch_decode_mla( o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) - # check is nan - assert not torch.isnan(o_ref).any(), "o_ref is nan" - assert not torch.isnan(output).any(), "output is nan" + if backend == "trtllm-gen": + # check is nan + assert not torch.isnan(o_ref).any(), "o_ref is nan" + assert not torch.isnan(output).any(), "output is nan" - if dtype == torch.float8_e4m3fn: - try: - torch.testing.assert_close( - output, - o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), - rtol=1e-1, - atol=1e-1, - ) # todo: do reference with normal attention? - except AssertionError as e: - print("output:", output) - print("o_ref:", o_ref) - raise e - else: - try: - torch.testing.assert_close( - output, - o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), - rtol=1e-2, - atol=1e-2, - ) - except AssertionError as e: - print("output:", output) - print("o_ref:", o_ref) - raise e + if dtype == torch.float8_e4m3fn: + try: + torch.testing.assert_close( + output, + o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), + rtol=1e-1, + atol=1e-1, + ) # todo: do reference with normal attention? 
+ except AssertionError as e: + print("output:", output) + print("o_ref:", o_ref) + raise e + else: + try: + torch.testing.assert_close( + output, + o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), + rtol=1e-2, + atol=1e-2, + ) + except AssertionError as e: + print("output:", output) + print("o_ref:", o_ref) + raise e + elif backend == "xqa": + atol = 0.05 + rtol = 0.05 + + diff_abs = torch.abs( + o_ref.view(batch_size, q_len_per_request, num_q_heads, -1) - output + ) + diff_rel = diff_abs / ( + torch.abs(o_ref.view(batch_size, q_len_per_request, num_q_heads, -1)) + 1e-8 + ) + + within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) + + pass_ratio = within_tolerance.float().mean().item() + + required_ratio = 0.95 + assert pass_ratio >= required_ratio, ( + f"Total {o_ref.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, " + f"require at least {required_ratio:.1%}" + ) diff --git a/tests/attention/test_xqa_mla_batch_decode.py b/tests/attention/test_xqa_mla_batch_decode.py new file mode 100644 index 0000000000..4d3abb52e1 --- /dev/null +++ b/tests/attention/test_xqa_mla_batch_decode.py @@ -0,0 +1,187 @@ +import pytest +import torch + +import flashinfer +from flashinfer.utils import get_compute_capability + +global_workspace_buffer = None # can.be empty initialized +global_xqa_workspace_buffer = None # must be zero initialized +workspace_size = 128 * 1024 * 1024 + + +@pytest.mark.parametrize( + "batch_size", + [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024], +) +@pytest.mark.parametrize("scale", [1.0, 0.5]) +@pytest.mark.parametrize("page_size", [32, 64, 128]) +@pytest.mark.parametrize("enable_pdl", [True, False, None]) +def test_xqa_mla_batch_decode( + batch_size: int, + scale: float, + page_size: int, + enable_pdl: bool, +): + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] != 12: + pytest.skip("These tests are only guaranteed to work on SM120 GPUs.") + + torch.manual_seed(42) + dtype = torch.float8_e4m3fn + q_len_per_request = 1 + device = "cuda:0" + + # Fixed max sequence length + max_seq_len = 1024 + + # Deepseek attention config (decode-MLA) + num_q_heads = 128 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + kv_lora_rank = 512 + + # Initialize tensors + query = torch.randn( + batch_size, + q_len_per_request, + num_q_heads, + kv_lora_rank + qk_rope_head_dim, + device=device, + ).to(dtype) + + num_tokens = max_seq_len * batch_size + num_blocks = (num_tokens + page_size - 1) // page_size + + # Sequence lengths and block tables + seq_lens = [torch.randint(1, max_seq_len, (1,)).item() for _ in range(batch_size)] + seq_lens[-1] = max_seq_len + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) + + blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size + max_num_blocks_per_seq = blocks_per_seq.max().item() + + # Generate random but unique block IDs for all sequences + total_blocks_needed = int(blocks_per_seq.sum().item()) + all_block_ids = torch.randperm( + total_blocks_needed, device=device + ) # Random permutation + + # Generate unique block IDs for all sequences + block_id = 0 + block_tables = torch.zeros( + (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device + ) + + # Populate block tables and track block assignments + block_id = 0 + for i in range(batch_size): + num_blocks_needed = int(blocks_per_seq[i].item()) + block_tables[i, :num_blocks_needed] = all_block_ids[ + block_id : block_id + num_blocks_needed + ] + block_id += 
num_blocks_needed + + # Create interleaved KV cache + # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases + kv_cache = torch.randn( + size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device + ).to(dtype) + # (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim) + + global global_workspace_buffer, global_xqa_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + workspace_size, dtype=torch.int8, device=device + ) + if global_xqa_workspace_buffer is None: + global_xqa_workspace_buffer = torch.zeros( + workspace_size, dtype=torch.int8, device=device + ) + workspace_buffer = global_xqa_workspace_buffer + workspace_buffer_ref = global_workspace_buffer + + # Run decode-MLA + output = flashinfer.decode.xqa_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=block_tables, + seq_lens=seq_lens_tensor, + max_seq_len=max_seq_len, + bmm1_scale=scale / ((128 + 64) ** 0.5), + bmm2_scale=1.0, + enable_pdl=enable_pdl, + ) + + # Run reference attention and align output + sm_scale = scale / ( + (128 + 64) ** 0.5 + ) # use head dimension before matrix absorption + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + workspace_buffer_ref, + backend="fa2", + ) + + if dtype == torch.float8_e4m3fn: + # convert query and kv_cache to bfloat16 + query = query.to(torch.bfloat16) + kv_cache = kv_cache.to(torch.bfloat16) + + q_indptr = ( + torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) + * q_len_per_request + ) + kv_indptr = torch.zeros_like(q_indptr) + kv_indptr[1:] = torch.cumsum(blocks_per_seq, dim=0) + kv_indices = all_block_ids.int() + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + seq_lens_tensor, + num_q_heads, + kv_lora_rank, + qk_rope_head_dim, + page_size, + True, + sm_scale, + query.dtype, + kv_cache.dtype, + ) + q_nope = query[..., :kv_lora_rank].view( + batch_size * q_len_per_request, num_q_heads, kv_lora_rank + ) + q_pe = query[..., kv_lora_rank:].view( + batch_size * q_len_per_request, num_q_heads, qk_rope_head_dim + ) + + # todo: fix kv_cache + ckv = kv_cache[..., :kv_lora_rank] + kpe = kv_cache[..., kv_lora_rank:] + + o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) + + atol = 0.05 + rtol = 0.05 + + diff_abs = torch.abs( + o_ref.view(batch_size, q_len_per_request, num_q_heads, -1) - output + ) + diff_rel = diff_abs / ( + torch.abs(o_ref.view(batch_size, q_len_per_request, num_q_heads, -1)) + 1e-8 + ) + + within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) + + pass_ratio = within_tolerance.float().mean().item() + + required_ratio = 0.95 + assert pass_ratio >= required_ratio, ( + f"Total {o_ref.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, " + f"require at least {required_ratio:.1%}" + ) From fbdb4396a91158d43144b82c3db8594fccf33341 Mon Sep 17 00:00:00 2001 From: Lain Date: Mon, 10 Nov 2025 18:19:50 -0800 Subject: [PATCH 047/130] Enable renormalize(naive) routing for fp8 per-tensor (#2030) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Disable expert weights in the FC1 except for Llama routing. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. 
### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Re-enabled Renormalize routing that was previously blocked. * Made token_scales available for Llama4 routing. * Corrected GEMM1 input so the proper data source is used during MoE processing. * **Tests** * Added FP8PerTensorMoe to test parameterization. * Expanded Renormalize and DeepSeekV3 test coverage and removed related skips. Signed-off-by: siyuanf --- csrc/trtllm_fused_moe_kernel_launcher.cu | 3 +++ csrc/trtllm_fused_moe_runner.cu | 2 +- include/flashinfer/trtllm/fused_moe/runner.h | 4 ++++ tests/moe/test_trtllm_gen_fused_moe.py | 22 +++++++++++++++++--- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 0688c1e97d..f3c45e2ec0 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -584,6 +584,9 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); workspace.expert_weights = expert_weights.data_ptr(); + if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { + workspace.token_scales = expert_weights.data_ptr(); // Consumed by permuteGemm1 kernel + } } void check_moe() const override { diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 21a2cad4b5..b5ff5757c9 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -518,7 +518,7 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d auto const& config = mPassingConfigs[configIndex]; mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, - args.gemm1_weights_scale, workspace.expert_weights, args.output1_scales_scalar, + args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar, args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output, workspace.gemm1_output_scale, args.top_k, args.hidden_size, diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 8d99902d67..3941a23249 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -305,7 +305,11 @@ struct MoEWorkspace { int32_t* expanded_idx_to_permuted_idx = nullptr; int32_t* permuted_idx_to_expanded_idx = nullptr; int32_t* permuted_idx_to_token_idx = nullptr; + + // consumed by finalize kernel void* expert_weights = nullptr; // [num_tokens, top_k] in bfloat16 = mDtypeExpW + // consumed by permuteGemm1 kernel + void* token_scales = nullptr; int32_t* cta_idx_xy_to_batch_idx = nullptr; int32_t* cta_idx_xy_to_mn_limit = nullptr; diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 1427b15245..4706b4c87a 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -2275,6 
+2275,7 @@ def run_moe_test( [ pytest.param(BF16Moe(), id="BF16xBF16"), pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), + pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), @@ -2293,7 +2294,12 @@ def run_moe_test( "routed_scaling": None, "has_routing_bias": False, "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe], + "compatible_moe_impls": [ + FP8PerTensorMoe, + FP8BlockScaleMoe, + FP4Moe, + BF16Moe, + ], "compatible_intermediate_size": [384, 768, 1024], }, id="Qwen3", @@ -2308,7 +2314,12 @@ def run_moe_test( "routed_scaling": None, "has_routing_bias": False, "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe], + "compatible_moe_impls": [ + FP8PerTensorMoe, + FP8BlockScaleMoe, + FP4Moe, + BF16Moe, + ], "compatible_intermediate_size": [384, 1024], }, id="Renorm", @@ -2323,7 +2334,12 @@ def run_moe_test( "routed_scaling": None, "has_routing_bias": False, "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe], + "compatible_moe_impls": [ + FP8PerTensorMoe, + FP8BlockScaleMoe, + FP4Moe, + BF16Moe, + ], "compatible_intermediate_size": [512], }, id="Qwen3_next", From 11177e8d0ebf32599786afcdc46211855d37074b Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 11 Nov 2025 07:55:13 -0800 Subject: [PATCH 048/130] unittest: improve the efficiency of xqa unittests (#2075) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description The implementation of xqa unittests are sub-optimal: we use lots of cpu index calculation and slicing operations. This PR refactors the unittest to use tensor operations as much as possible and remove redundant logics. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes cc @qsang-nv @jiahanc @bkryu ## Summary by CodeRabbit * **Tests** * Refactored internal test infrastructure for attention operations with vectorized batch processing, improving test efficiency and GPU utilization. * **Refactor** * Optimized cache assembly logic and data handling patterns in test utilities for improved performance. 
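For reviewers unfamiliar with the pattern: the refactor below replaces per-token Python loops with a single advanced-indexing gather over the paged cache. The snippet here is a minimal, self-contained sketch of that pattern only; the tensor names, sizes, and the NHD layout are illustrative placeholders, not the test's actual variables.

```python
import torch

# Illustrative sizes only, not the test's actual configuration.
num_pages, tokens_per_page, num_kv_heads, head_dim = 8, 32, 2, 64
seq_len, idx_head = 100, 0

# Paged K cache in NHD layout: [num_pages, tokens_per_page, num_kv_heads, head_dim].
cache_k = torch.randn(num_pages, tokens_per_page, num_kv_heads, head_dim)

# Page table for one sequence: which physical page backs each logical page.
num_logical_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
page_indices = torch.randperm(num_pages)[:num_logical_pages]

# Vectorized gather: one advanced-indexing op assembles the contiguous sequence,
# replacing a per-token Python loop over pages and in-page offsets.
k_contig = cache_k[page_indices, :, idx_head, :].reshape(-1, head_dim)[:seq_len]

# Slow reference doing the same thing token by token, for comparison.
k_loop = torch.stack(
    [
        cache_k[page_indices[i // tokens_per_page], i % tokens_per_page, idx_head]
        for i in range(seq_len)
    ]
)
assert torch.equal(k_contig, k_loop)
```

The same idea underlies the page-table shuffling, the zeroing of unused pages, and the contiguous K/V assembly in the updated tests: one batched indexing operation on the GPU instead of many scalar lookups on the CPU.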
--- tests/attention/test_xqa.py | 322 ++++++++++------------- tests/attention/test_xqa_batch_decode.py | 123 +++++---- 2 files changed, 213 insertions(+), 232 deletions(-) diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index 5701bdc1b8..172135c571 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -31,38 +31,10 @@ def div_up(a, b): beam_width = 1 -class CacheSeq: - def __init__( - self, - pool: torch.Tensor, - page_indices: torch.Tensor, - nb_heads: int, - idx_head: int, - tokens_per_page: int = 32, - kv_layout: str = "NHD", - ): - self.pool = pool - self.page_indices = page_indices - self.nb_heads = nb_heads - self.idx_head = idx_head - self.tokens_per_page = tokens_per_page - self.kv_layout = kv_layout - - def __getitem__(self, i: int) -> torch.Tensor: - page_idx = self.page_indices[i // self.tokens_per_page].to(torch.int32) - token_in_page = i % self.tokens_per_page - if self.kv_layout == "NHD": - # NHD layout: [page_idx, token_in_page, idx_head, :] - return self.pool[page_idx, token_in_page, self.idx_head, :] - else: # HND - # HND layout: [page_idx, idx_head, token_in_page, :] - return self.pool[page_idx, self.idx_head, token_in_page, :] - - def ref_attention( q, - k_cache_seq, - v_cache_seq, + k_cache, # Changed: now takes full tensor [seq_len, dim] + v_cache, # Changed: now takes full tensor [seq_len, dim] seq_len, q_scale, kv_scale, @@ -89,18 +61,12 @@ def ref_attention( q_f32 = q.to(torch.float32) # [head_grp_size, valid_elems_per_head] - k_cache_f32 = torch.zeros( - seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda" - ) - # V cache: load only valid_elems_per_v_head dimensions - v_cache_f32 = torch.zeros( - seq_len, valid_elems_per_v_head, dtype=torch.float32, device="cuda" - ) - - for j in range(seq_len): - k_cache_f32[j] = k_cache_seq[j].to(torch.float32) - # For MLA: V cache storage is 576 but only first 512 elements are valid - v_cache_f32[j] = v_cache_seq[j][:valid_elems_per_v_head].to(torch.float32) + # Directly use the pre-assembled cache tensors + k_cache_f32 = k_cache[:seq_len].to(torch.float32) # [seq_len, valid_elems_per_head] + # For MLA: V cache storage is 576 but only first 512 elements are valid + v_cache_f32 = v_cache[:seq_len, :valid_elems_per_v_head].to( + torch.float32 + ) # [seq_len, valid_elems_per_v_head] # q_f32: [head_grp_size, valid_elems_per_head] # k_cache_f32: [seq_len, valid_elems_per_head] @@ -223,12 +189,12 @@ def test_xqa( ) q_heads.normal_(0, 1) if use_attention_sinks: - attention_sinks = torch.zeros( - nb_k_heads, head_grp_size, dtype=torch.float32, device="cuda" + # Vectorized creation of attention_sinks + j_indices = torch.arange(head_grp_size, device="cuda") + attention_sinks = 2.0 + (j_indices % 4).float() + attention_sinks = ( + attention_sinks.unsqueeze(0).expand(nb_k_heads, head_grp_size).contiguous() ) - for i in range(nb_k_heads): - for j in range(head_grp_size): - attention_sinks[i, j] = 2.0 + float(j % 4) else: attention_sinks = None if use_sliding_window: @@ -287,65 +253,63 @@ def test_xqa( # and prevent overflow during computation. The factor 4.0 is chosen empirically. 
cache_k_heads /= 4.0 cache_v_heads /= 4.0 - page_list_arg = torch.zeros( - batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda" + # Vectorized page list initialization + total_pages = batch_size * nb_pages_per_seq + page_list_arg = torch.arange(total_pages, dtype=torch.int32, device="cuda").view( + batch_size, nb_pages_per_seq ) - # Initialize page list sequentially - page_idx = 0 - for batch in range(batch_size): - for page in range(nb_pages_per_seq): - page_list_arg[batch, page] = page_idx - page_idx += 1 - + # Shuffle page indices flattened = page_list_arg.flatten() - indices = torch.randperm(flattened.numel()) + indices = torch.randperm(flattened.numel(), device="cuda") shuffled_flat = flattened[indices] - page_list_arg = shuffled_flat.view(page_list_arg.shape) - - def cache_head_at( - batch, - is_k, - idx_kv_head, - pos, - cache_k_heads, - cache_v_heads, - page_list, - beam_width, - nb_k_heads, - tokens_per_page, - kv_layout, - ): - # K and V share page indices - page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32) - token_in_page = pos % tokens_per_page - - cache = cache_k_heads if is_k else cache_v_heads - if kv_layout == "NHD": - # NHD layout: [page_idx, token_in_page, idx_kv_head, :] - return cache[page_idx, token_in_page, idx_kv_head, :] - else: # HND - # HND layout: [page_idx, idx_kv_head, token_in_page, :] - return cache[page_idx, idx_kv_head, token_in_page, :] - - for batch in range(batch_size): - for kv in range(2): - for idx_kv_head in range(nb_k_heads): - for pos in range(seq_len, max_seq_len): - cache_head = cache_head_at( - batch, - kv == 0, - idx_kv_head, - pos, - cache_k_heads, - cache_v_heads, - page_list_arg, - beam_width, - nb_k_heads, - tokens_per_page, - kv_layout, + page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq) + + # Vectorized zeroing of unused cache positions using advanced indexing + if seq_len < max_seq_len: + # Collect all (page_id, token_pos) pairs that need to be zeroed across all batches + start_page = seq_len // tokens_per_page + end_page = nb_pages_per_seq + + if start_page < end_page: + # Get all page IDs that need partial/full zeroing: [batch_size, num_pages_to_zero] + pages_to_zero = page_list_arg[ + :, start_page:end_page + ] # [batch_size, num_pages_to_zero] + + # For the first page (start_page), zero from [seq_len % tokens_per_page, tokens_per_page) + # For subsequent pages, zero entirely [0, tokens_per_page) + first_page_ids = pages_to_zero[:, 0] # [batch_size] + token_start_in_first_page = seq_len % tokens_per_page + + if token_start_in_first_page > 0: + # Zero partial first page for all batches at once + if kv_layout == "NHD": + cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = ( + 0.0 + ) + cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = ( + 0.0 + ) + else: # HND + cache_k_heads[first_page_ids, :, token_start_in_first_page:, :] = ( + 0.0 + ) + cache_v_heads[first_page_ids, :, token_start_in_first_page:, :] = ( + 0.0 ) - cache_head.fill_(0.0) + + # Zero all subsequent full pages (if any) for all batches at once + if pages_to_zero.shape[1] > 1: + remaining_page_ids = pages_to_zero[ + :, 1: + ].flatten() # Flatten all remaining pages + if kv_layout == "NHD": + cache_k_heads[remaining_page_ids, :, :, :] = 0.0 + cache_v_heads[remaining_page_ids, :, :, :] = 0.0 + else: # HND + cache_k_heads[remaining_page_ids, :, :, :] = 0.0 + cache_v_heads[remaining_page_ids, :, :, :] = 0.0 seq_len_list = torch.zeros( batch_size, beam_width, dtype=torch.uint32, device="cuda" @@ -385,30 
+349,36 @@ def cache_head_at( for req in range(batch_size): for b in range(beam_width): for idx_k_head in range(nb_k_heads): - # K and V use separate pools but share page indices - k_cache_seq = CacheSeq( - pool=cache_k_heads, - page_indices=page_list_arg[req], - nb_heads=nb_k_heads, - idx_head=idx_k_head, - tokens_per_page=tokens_per_page, - kv_layout=kv_layout, - ) - v_cache_seq = CacheSeq( - pool=cache_v_heads, - page_indices=page_list_arg[req], - nb_heads=nb_k_heads, - idx_head=idx_k_head, - tokens_per_page=tokens_per_page, - kv_layout=kv_layout, - ) + # Assemble contiguous K/V cache from paged memory using advanced indexing + num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page + pages = page_list_arg[req, :num_pages] # [num_pages] + + # Gather all pages at once + if kv_layout == "NHD": + # [num_pages, tokens_per_page, nb_k_heads, head_dim] + k_pages = cache_k_heads[ + pages, :, idx_k_head, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, :, idx_k_head, :] + else: # HND + # [num_pages, nb_k_heads, tokens_per_page, head_dim] + k_pages = cache_k_heads[ + pages, idx_k_head, :, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, idx_k_head, :, :] + + # Reshape to contiguous sequence + k_cache = k_pages.reshape( + -1, valid_elems_per_head + ) # [num_pages*tokens_per_page, head_dim] + v_cache = v_pages.reshape(-1, valid_elems_per_head) ref_output = ref_attention( q=q_heads[req][b][ idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size ], - k_cache_seq=k_cache_seq, - v_cache_seq=v_cache_seq, + k_cache=k_cache, + v_cache=v_cache, seq_len=seq_len, q_scale=q_scale, kv_scale=kv_cache_scale, @@ -520,59 +490,41 @@ def test_xqa_mla( cache_k_heads /= 4.0 cache_v_heads /= 4.0 - page_list_arg = torch.zeros( - batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda" + # Vectorized page list initialization + total_pages = batch_size * nb_pages_per_seq + page_list_arg = torch.arange(total_pages, dtype=torch.int32, device="cuda").view( + batch_size, nb_pages_per_seq ) - # Initialize page list sequentially - page_idx = 0 - for batch in range(batch_size): - for page in range(nb_pages_per_seq): - page_list_arg[batch, page] = page_idx - page_idx += 1 - + # Shuffle page indices flattened = page_list_arg.flatten() - indices = torch.randperm(flattened.numel()) + indices = torch.randperm(flattened.numel(), device="cuda") shuffled_flat = flattened[indices] - page_list_arg = shuffled_flat.view(page_list_arg.shape) - - def cache_head_at( - batch, - is_k, - idx_kv_head, - pos, - cache_k_heads, - cache_v_heads, - page_list, - beam_width, - nb_k_heads, - tokens_per_page, - ): - # K and V share page indices - page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32) - token_in_page = pos % tokens_per_page - - # NHD layout: [page_idx, token_in_page, idx_kv_head, :] - cache = cache_k_heads if is_k else cache_v_heads - return cache[page_idx, token_in_page, idx_kv_head, :] - - for batch in range(batch_size): - for kv in range(2): - for idx_kv_head in range(nb_k_heads): - for pos in range(seq_len, max_seq_len): - cache_head = cache_head_at( - batch, - kv == 0, - idx_kv_head, - pos, - cache_k_heads, - cache_v_heads, - page_list_arg, - beam_width, - nb_k_heads, - tokens_per_page, - ) - cache_head.fill_(0.0) + page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq) + + # Vectorized zeroing of unused cache positions (NHD layout only for MLA) + if seq_len < max_seq_len: + start_page = seq_len // tokens_per_page + end_page = 
nb_pages_per_seq + + if start_page < end_page: + pages_to_zero = page_list_arg[ + :, start_page:end_page + ] # [batch_size, num_pages_to_zero] + + first_page_ids = pages_to_zero[:, 0] # [batch_size] + token_start_in_first_page = seq_len % tokens_per_page + + if token_start_in_first_page > 0: + # Zero partial first page for all batches at once (NHD layout) + cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0 + cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0 + + # Zero all subsequent full pages (if any) for all batches at once + if pages_to_zero.shape[1] > 1: + remaining_page_ids = pages_to_zero[:, 1:].flatten() + cache_k_heads[remaining_page_ids, :, :, :] = 0.0 + cache_v_heads[remaining_page_ids, :, :, :] = 0.0 seq_len_list = torch.zeros( batch_size, beam_width, dtype=torch.uint32, device="cuda" @@ -608,28 +560,26 @@ def cache_head_at( for req in range(batch_size): for b in range(beam_width): for idx_k_head in range(nb_k_heads): - # K and V use separate pools but share page indices - k_cache_seq = CacheSeq( - pool=cache_k_heads, - page_indices=page_list_arg[req], - nb_heads=nb_k_heads, - idx_head=idx_k_head, - tokens_per_page=tokens_per_page, - ) - v_cache_seq = CacheSeq( - pool=cache_v_heads, - page_indices=page_list_arg[req], - nb_heads=nb_k_heads, - idx_head=idx_k_head, - tokens_per_page=tokens_per_page, - ) + # Assemble contiguous K/V cache from paged memory using advanced indexing + num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page + pages = page_list_arg[req, :num_pages] # [num_pages] + + # NHD layout: [num_pages, tokens_per_page, nb_k_heads, head_dim] + k_pages = cache_k_heads[ + pages, :, idx_k_head, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, :, idx_k_head, :] + + # Reshape to contiguous sequence + k_cache = k_pages.reshape(-1, valid_elems_per_head_qk) + v_cache = v_pages.reshape(-1, valid_elems_per_head_qk) ref_output = ref_attention( q=q_heads[req][b][ idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size ], - k_cache_seq=k_cache_seq, - v_cache_seq=v_cache_seq, + k_cache=k_cache, + v_cache=v_cache, seq_len=seq_len, q_scale=q_scale * math.sqrt(576), kv_scale=kv_cache_scale, diff --git a/tests/attention/test_xqa_batch_decode.py b/tests/attention/test_xqa_batch_decode.py index fbeac45354..7a2bd3356a 100644 --- a/tests/attention/test_xqa_batch_decode.py +++ b/tests/attention/test_xqa_batch_decode.py @@ -143,28 +143,33 @@ def create_kv_cache( def create_page_table(batch_size, seq_lens, page_size): + # Ensure seq_lens is on GPU and calculate page_per_seq on GPU + seq_lens = seq_lens.to(GPU_DEVICE) page_per_seq = (seq_lens + page_size - 1) // page_size max_num_pages_per_seq = torch.max(page_per_seq).item() - # Generate random but unique page IDs for all sequences + # Generate sequential page IDs total_pages_needed = torch.sum(page_per_seq).item() - all_page_ids = torch.randperm( + all_page_ids = torch.arange( total_pages_needed, dtype=torch.int32, device=GPU_DEVICE ) - # Generate unique page IDs for all sequences - page_tables = torch.zeros( - (batch_size, max_num_pages_per_seq), dtype=torch.int32, device=GPU_DEVICE + # Use cumsum to create page offsets for each sequence + page_offsets = torch.cat( + [ + torch.tensor([0], device=GPU_DEVICE, dtype=torch.int32), + torch.cumsum(page_per_seq[:-1], dim=0, dtype=torch.int32), + ] ) - # Populate page tables and track page assignments - page_id = 0 - for i in range(batch_size): - num_pages_needed = page_per_seq[i] - page_tables[i, 
:num_pages_needed] = all_page_ids[ - page_id : page_id + num_pages_needed - ] - page_id += num_pages_needed + # Create page tables using broadcasting + page_idx_range = torch.arange( + max_num_pages_per_seq, device=GPU_DEVICE, dtype=torch.int32 + ).unsqueeze(0) + page_tables = ( + page_offsets.unsqueeze(1) + page_idx_range + ) # [batch_size, max_num_pages_per_seq] + return page_tables, all_page_ids, page_per_seq @@ -179,43 +184,69 @@ def flatten_paged_kv( """Build flat K/V and token-level indptr from paged KV cache and page table. Supports both NHD and HND layouts. + Optimized to avoid loops using vectorized operations. """ device = ref_kv_cache.device batch_size = int(page_table.shape[0]) - # Move loop-control tensors to CPU to avoid GPU sync in loops - page_table_cpu = page_table.cpu() - seq_lens_cpu = seq_lens.cpu() - kv_last_page_len_cpu = kv_last_page_len.cpu() - page_per_seq = (seq_lens_cpu + page_size - 1) // page_size - k_list = [] - v_list = [] - for i in range(batch_size): - pages_i = int(page_per_seq[i].item()) - last_len_i = int(kv_last_page_len_cpu[i].item()) - for j in range(pages_i): - page_id = int(page_table_cpu[i, j].item()) - if kv_layout == "NHD": - # NHD: [page_id, 0/1, page_size, num_heads, head_dim] - k_page = ref_kv_cache[page_id, 0] # [page_size, num_heads, head_dim] - v_page = ref_kv_cache[page_id, 1] - if j == pages_i - 1: - k_page = k_page[:last_len_i, :, :] - v_page = v_page[:last_len_i, :, :] - else: # HND - # HND: [page_id, 0/1, num_heads, page_size, head_dim] - k_page = ref_kv_cache[page_id, 0] # [num_heads, page_size, head_dim] - v_page = ref_kv_cache[page_id, 1] - if j == pages_i - 1: - k_page = k_page[:, :last_len_i, :] - v_page = v_page[:, :last_len_i, :] - # Transpose to NHD: [num_heads, page_size, head_dim] -> [page_size, num_heads, head_dim] - k_page = k_page.transpose(0, 1) - v_page = v_page.transpose(0, 1) - k_list.append(k_page) - v_list.append(v_page) - k_flat = torch.cat(k_list, dim=0) - v_flat = torch.cat(v_list, dim=0) + # Calculate number of pages per sequence + page_per_seq = (seq_lens + page_size - 1) // page_size + max_pages = int(page_per_seq.max().item()) + + # Gather all pages at once using advanced indexing + # page_table shape: [batch_size, max_pages] + if kv_layout == "NHD": + # ref_kv_cache: [num_pages_total, 2, page_size, num_heads, head_dim] + # Gather: [batch_size, max_pages, page_size, num_heads, head_dim] + k_pages = ref_kv_cache[ + page_table, 0 + ] # [batch_size, max_pages, page_size, num_heads, head_dim] + v_pages = ref_kv_cache[page_table, 1] + else: # HND + # ref_kv_cache: [num_pages_total, 2, num_heads, page_size, head_dim] + # Gather: [batch_size, max_pages, num_heads, page_size, head_dim] + k_pages = ref_kv_cache[ + page_table, 0 + ] # [batch_size, max_pages, num_heads, page_size, head_dim] + v_pages = ref_kv_cache[page_table, 1] + # Transpose to NHD: [batch_size, max_pages, num_heads, page_size, head_dim] -> [batch_size, max_pages, page_size, num_heads, head_dim] + k_pages = k_pages.transpose(2, 3) + v_pages = v_pages.transpose(2, 3) + + # Reshape to [batch_size, max_pages * page_size, num_heads, head_dim] + num_heads = k_pages.shape[-2] + head_dim = k_pages.shape[-1] + k_pages = k_pages.reshape(batch_size, max_pages * page_size, num_heads, head_dim) + v_pages = v_pages.reshape(batch_size, max_pages * page_size, num_heads, head_dim) + + # Create token indices for each sequence using vectorized operations + # For each batch, we need to extract [:seq_len] tokens + max_seq_len = seq_lens.max().item() + token_idx = 
torch.arange(max_seq_len, device=device, dtype=torch.int32).unsqueeze( + 0 + ) # [1, max_seq_len] + token_mask = token_idx < seq_lens.unsqueeze(1) # [batch_size, max_seq_len] + + # Gather valid tokens for all sequences at once + # Expand k_pages and v_pages to max_seq_len, then mask + k_gathered = k_pages[ + :, :max_seq_len, :, : + ] # [batch_size, max_seq_len, num_heads, head_dim] + v_gathered = v_pages[ + :, :max_seq_len, :, : + ] # [batch_size, max_seq_len, num_heads, head_dim] + + # Flatten and filter by mask + k_gathered_flat = k_gathered.reshape( + -1, num_heads, head_dim + ) # [batch_size * max_seq_len, num_heads, head_dim] + v_gathered_flat = v_gathered.reshape(-1, num_heads, head_dim) + token_mask_flat = token_mask.reshape(-1) # [batch_size * max_seq_len] + + # Keep only valid tokens + k_flat = k_gathered_flat[token_mask_flat] + v_flat = v_gathered_flat[token_mask_flat] + kv_indptr_tokens = torch.cat( [ torch.tensor([0], dtype=torch.int32, device=device), From eccbdde95558d0487cf50d4dba01b2e2091c2f8d Mon Sep 17 00:00:00 2001 From: Wenxuan Tan Date: Tue, 11 Nov 2025 21:40:04 -0600 Subject: [PATCH 049/130] minor: canonicalize TFLOPS calculation (#2069) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Refactor** * TFLOPS computation standardized across attention benchmarks so reported performance metrics consistently account for actual sequence and batch lengths. * **Bug Fixes** * Added checks to prevent invalid mixed-length causal inputs, avoiding misleading benchmark results. * **Chores** * Renamed timing parameter in the benchmark utility for clearer intent. 
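For context, here is a rough, self-contained sketch of the accounting that the shared helper centralizes: two GEMMs per head (QK^T and PV) at 2 FLOPs per multiply-add, summed over per-sequence lengths, with causal masking approximated as keeping half of the work. The function below is a simplified illustration written for this note, not the library implementation (which, per the diff, additionally rejects causal inputs with q_len greater than kv_len).

```python
import math
import torch


def approx_attention_tflops_per_sec(
    actual_seq_lens_q, actual_seq_lens_kv, head_dim_qk, head_dim_vo,
    num_qo_heads, causal, ms,
):
    """Rough TFLOP/s estimate: two GEMMs (QK^T and PV), 2 FLOPs per multiply-add."""
    q = actual_seq_lens_q.to(torch.float64)
    kv = actual_seq_lens_kv.to(torch.float64)
    flops = 2.0 * num_qo_heads * (head_dim_qk + head_dim_vo) * torch.dot(q, kv)
    if causal:
        # Approximation: causal masking keeps about half of both GEMMs
        # (exact when q_len == kv_len, up to lower-order terms).
        flops = flops / 2.0
    # flops / (ms * 1e-3) / 1e12  ==  flops / ms / 1e9
    return flops.item() / ms / 1e9 if not math.isnan(ms) else 0.0


# Example: 4 sequences of 4096 tokens, 32 query heads, head_dim 128, 5 ms per run.
lens = torch.full((4,), 4096)
print(approx_attention_tflops_per_sec(lens, lens, 128, 128, 32, causal=False, ms=5.0))
```

Routing every benchmark through one helper like this is what keeps the reported numbers comparable: the ad-hoc per-benchmark formulas being removed below differed in how they handled causal masking and batch scaling.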
--- benchmarks/bench_blackwell_attention.py | 22 +++++---- benchmarks/bench_block_sparse_attention.py | 15 ++++++- benchmarks/bench_hopper_attention.py | 52 +++++++++++++--------- benchmarks/bench_hopper_fp8_attention.py | 18 +++++--- flashinfer/testing/utils.py | 19 ++++++-- 5 files changed, 87 insertions(+), 39 deletions(-) diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py index 52452e05a8..73b0cd0b3c 100644 --- a/benchmarks/bench_blackwell_attention.py +++ b/benchmarks/bench_blackwell_attention.py @@ -18,7 +18,10 @@ import torch import flashinfer -from flashinfer.testing.utils import bench_gpu_time +from flashinfer.testing.utils import ( + bench_gpu_time, + attention_tflops_per_sec_with_actual_seq_lens, +) def bench_fmha_blackwell( @@ -69,14 +72,17 @@ def bench_fmha_blackwell( ) ms = np.median(measurements) - def flops(ms): - if causal: - return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9 - else: - return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9 - + TFLOPS = attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), qkv_len), + torch.full((batch_size,), qkv_len), + head_dim, + head_dim, + num_heads, + causal, + ms, + ) print( - f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s" + f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s" ) diff --git a/benchmarks/bench_block_sparse_attention.py b/benchmarks/bench_block_sparse_attention.py index e2a51012f5..2da2478a6f 100644 --- a/benchmarks/bench_block_sparse_attention.py +++ b/benchmarks/bench_block_sparse_attention.py @@ -18,7 +18,10 @@ import torch import flashinfer -from flashinfer.testing.utils import bench_gpu_time +from flashinfer.testing.utils import ( + bench_gpu_time, + attention_tflops_per_sec_with_actual_seq_lens, +) def bench_variable_block_sparse_attention( @@ -120,7 +123,15 @@ def bench_variable_block_sparse_attention( ) def flops(ms): - return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return attention_tflops_per_sec_with_actual_seq_lens( + torch.tensor([seq_len]), + torch.tensor([seq_len]), + head_dim, + head_dim, + num_qo_heads, + False, + ms, + ) print( f"bench_variable_block_sparse_attention (num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, seq_len={seq_len}, num_blocks_row={num_blocks_row}, num_blocks_col={num_blocks_col}, block_density={block_density}), sparse fa2-template: {flops(sparse_ms_fa2):.3f} TFLOPs/s, sparse fa3-template: {flops(sparse_ms_fa3):.3f} TFLOPs/s, dense fa2-template: {flops(dense_sm80_ms):.3f} TFLOPs/s, dense fa3-template: {flops(dense_sm90_ms):.3f} TFLOPs/s" diff --git a/benchmarks/bench_hopper_attention.py b/benchmarks/bench_hopper_attention.py index 6ad2fdaa1b..c1e56e6225 100644 --- a/benchmarks/bench_hopper_attention.py +++ b/benchmarks/bench_hopper_attention.py @@ -18,7 +18,10 @@ import torch import flashinfer -from flashinfer.testing.utils import bench_gpu_time +from flashinfer.testing.utils import ( + bench_gpu_time, + attention_tflops_per_sec_with_actual_seq_lens, +) def bench_single_prefill(seq_len, num_heads, causal, head_dim): @@ -41,10 +44,15 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): ) def flops(ms): - if causal: - return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 - else: - 
return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return attention_tflops_per_sec_with_actual_seq_lens( + torch.tensor([seq_len]), + torch.tensor([seq_len]), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) print( f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" @@ -97,14 +105,15 @@ def bench_batch_ragged_prefill(batch_size, num_heads, seq_len, causal, head_dim) ) def flops(ms): - if causal: - return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 - ) - else: - return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 - ) + return attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), seq_len), + torch.full((batch_size,), seq_len), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) print( f"bench_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" @@ -176,14 +185,15 @@ def bench_batch_paged_prefill( ) def flops(ms): - if causal: - return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 - ) - else: - return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 - ) + return attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), seq_len), + torch.full((batch_size,), seq_len), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) print( f"bench_batch_paged_prefill (page_size={page_size} batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" diff --git a/benchmarks/bench_hopper_fp8_attention.py b/benchmarks/bench_hopper_fp8_attention.py index 34d71d7f9e..89224af622 100644 --- a/benchmarks/bench_hopper_fp8_attention.py +++ b/benchmarks/bench_hopper_fp8_attention.py @@ -2,7 +2,10 @@ import torch import flashinfer -from flashinfer.testing.utils import bench_gpu_time +from flashinfer.testing.utils import ( + bench_gpu_time, + attention_tflops_per_sec_with_actual_seq_lens, +) def bench_single_prefill(seq_len, num_heads, causal, head_dim): @@ -45,10 +48,15 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): ) def flops(ms): - if causal: - return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 - else: - return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return attention_tflops_per_sec_with_actual_seq_lens( + torch.tensor([seq_len]), + torch.tensor([seq_len]), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) print( f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s" diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index 46ede8de2b..6bb979c382 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -277,6 +277,12 @@ def attention_flops( Returns: total_flops (int): Total FLOPs for the layer. 
""" + # Causal attention requires kv_len >= q_len + if qo_seqlen > kv_seqlen: + raise ValueError( + "qo_seqlen must be less than or equal to kv_seqlen for causal attention" + ) + if causal: bmm1_flops = ( batch_size @@ -323,6 +329,13 @@ def attention_flops_with_actual_seq_lens( Returns: total_flops (int): Total FLOPs for the layer. """ + # Causal attention requires kv_len >= q_len + # Otherwise right align if kv_len > q_len + if causal and (actual_seq_lens_q > actual_seq_lens_kv).any(): + raise ValueError( + "actual_seq_lens_q must be less than or equal to actual_seq_lens_kv for causal attention" + ) + if causal: bmm1_flops = ( torch.dot( @@ -412,7 +425,7 @@ def attention_tflops_per_sec_with_actual_seq_lens( head_dim_vo, num_qo_heads, causal, - time, + ms, ): """ Calculate TFLOPS per second for a given attention layer with actual sequence lengths. @@ -425,7 +438,7 @@ def attention_tflops_per_sec_with_actual_seq_lens( head_dim_vo (int): Head dimension of the value. num_qo_heads (int): Number of query heads. causal (bool): Whether to use causal masking. - time (float): Execution time in milliseconds. + ms (float): Execution time in milliseconds. Returns: tflops_per_sec (float): TFLOPS per second for the layer. @@ -438,7 +451,7 @@ def attention_tflops_per_sec_with_actual_seq_lens( num_qo_heads, causal, ) - return f.item() / time / 1e9 if not math.isnan(time) else 0.0 + return f.item() / ms / 1e9 if not math.isnan(ms) else 0.0 def attention_tb_per_sec( From 96e73b80a11ed095b4162e0279e554a0cf7986ee Mon Sep 17 00:00:00 2001 From: dongjiyingdjy <87510204+dongjiyingdjy@users.noreply.github.com> Date: Wed, 12 Nov 2025 13:29:05 +0800 Subject: [PATCH 050/130] fix: fix test_trtllm_gen_attention when max_seq_len < page_size (#2076) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Adjusted attention test calculations for K/V cache sizing to use per-sequence page allocation before scaling to the batch, improving alignment with expected memory allocation. * This refines test expectations around cache sizing without changing validation logic, reducing false positives in memory-related test scenarios. 
Signed-off-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com> --- tests/attention/test_trtllm_gen_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 4d1fe2891c..d279144195 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -102,8 +102,8 @@ def create_kv_cache( ): # Create separate K and V caches max_seq_len = torch.max(seq_lens).item() - num_tokens = max_seq_len * batch_size - num_pages = (num_tokens + page_size - 1) // page_size + num_pages_per_seq = (max_seq_len + page_size - 1) // page_size + num_pages = num_pages_per_seq * batch_size ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype] if kv_dtype != "fp8": # for fp8, create with high precision to generate scale. assert kv_dtype == ref_kv_dtype, ( From 53a6da4788ffa794c43637797f90116d1f29a37f Mon Sep 17 00:00:00 2001 From: qsang-nv <200703406+qsang-nv@users.noreply.github.com> Date: Wed, 12 Nov 2025 22:25:18 +0800 Subject: [PATCH 051/130] enable xqa fp8 output (#2081) --- csrc/flashinfer_xqa_binding.cu | 5 +--- csrc/xqa/mha.cu | 10 ++++---- csrc/xqa/mha.h | 8 +++---- csrc/xqa/mha_sm90.cu | 8 +++---- csrc/xqa/xqa_wrapper.cu | 7 ++---- flashinfer/aot.py | 1 + flashinfer/decode.py | 12 +++++++++- flashinfer/jit/xqa.py | 10 ++++++-- flashinfer/xqa.py | 23 ++++++++++++++++--- tests/attention/test_trtllm_gen_attention.py | 6 +++++ tests/attention/test_xqa.py | 24 ++++++++++++++++---- tests/attention/test_xqa_batch_decode.py | 7 ++++-- 12 files changed, 86 insertions(+), 35 deletions(-) diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu index 8556fb5e48..e21eb3a73d 100644 --- a/csrc/flashinfer_xqa_binding.cu +++ b/csrc/flashinfer_xqa_binding.cu @@ -27,10 +27,7 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla); #else void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, - int64_t slidingWinSize, double qScale, TensorView output, -#if LOW_PREC_OUTPUT - TensorView rcpOutScale, -#endif + int64_t slidingWinSize, double qScale, TensorView output, double rcpOutScale, TensorView q, tvm::ffi::Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index 715267bedc..9359eb5d12 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -1281,7 +1281,7 @@ CUBIN_EXPORT __global__ float qScale, OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif // NOTE: the input is actually Q buffer when integrated to TRT-LLM. 
IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], @@ -2165,7 +2165,7 @@ CUBIN_EXPORT __global__ } ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum); #if LOW_PREC_OUTPUT - voScale *= rcpOutScale[0]; + voScale *= rcpOutScale; #endif rescaleAcc(warp, acc, fullRescaleMask, rcpRowSum * ThrdRegRowMax::filled(voScale)); } @@ -2396,7 +2396,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( float qScale, OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], #if SPEC_DEC @@ -2447,7 +2447,7 @@ void launchMHA( #endif float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif #if USE_INPUT_KV InputHead const* qkv, @@ -2563,7 +2563,7 @@ static uint32_t const hostSmemSize = configureKernel(); void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h index ee4584ee84..872cd45059 100644 --- a/csrc/xqa/mha.h +++ b/csrc/xqa/mha.h @@ -95,7 +95,7 @@ void launchMHA( #endif float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif #if USE_INPUT_KV InputHead const* qkv, @@ -125,7 +125,7 @@ void launchMHA( void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, @@ -145,7 +145,7 @@ void launchHopperF8MHA( #endif float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif #if USE_INPUT_KV InputHead const* qkv, @@ -174,7 +174,7 @@ void launchHopperF8MHA( void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu index d0de67c372..9b751817c5 100644 --- a/csrc/xqa/mha_sm90.cu +++ b/csrc/xqa/mha_sm90.cu @@ -610,7 +610,7 @@ __launch_bounds__(128 * 3) float const qScale, OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] #if LOW_PREC_OUTPUT - float const* const rcpOutScale, + float rcpOutScale, #endif #if USE_INPUT_KV IOHead const* __restrict__ const qkv, // [nbReq][beamWidth][nbQHeads+nbKHeads+nbVHeads], @@ -957,7 +957,7 @@ __launch_bounds__(128 * 3) constexpr float xScale = 1.f / kE4M3_MAX; #if LOW_PREC_OUTPUT - float const oScale = rcpOutScale[0]; + float const oScale = rcpOutScale; #else constexpr float oScale = 1.F; #endif @@ -2910,7 +2910,7 @@ void launchHopperF8MHA( #endif float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif #if USE_INPUT_KV InputHead const* qkv, @@ -3037,7 +3037,7 @@ static uint32_t const hostSmemSize = configureKernel(); 
void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float rcpOutScale, #endif InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 1ac25fcf91..796a4b33ef 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -45,10 +45,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp #else void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, - int64_t slidingWinSize, double qScale, TensorView output, -#if LOW_PREC_OUTPUT - TensorView rcpOutScale, -#endif + int64_t slidingWinSize, double qScale, TensorView output, double rcpOutScale, TensorView q, Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, @@ -70,7 +67,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT - reinterpret_cast(rcpOutScale.data_ptr()), + rcpOutScale, #endif reinterpret_cast(q.data_ptr()), attentionSinksPtr, reinterpret_cast(kCacheVLLM.data_ptr()), diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 58d55264d0..5801cc933b 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -404,6 +404,7 @@ def gen_xqa( head_dim=head_size, head_group_ratio=head_grp_size, use_sliding_window=use_sliding_window, + output_dtype=input_type, ) if has_sm120 or has_sm121: diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 5db7d95a51..574f8a024c 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2077,6 +2077,7 @@ def trtllm_batch_decode_with_kv_cache( enable_pdl: Optional[bool] = None, backend: str = "auto", q_len_per_req: Optional[int] = 1, + o_scale: Optional[float] = 1.0, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -2142,6 +2143,9 @@ def trtllm_batch_decode_with_kv_cache( For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. For sm_90 (hopper architecture) and sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. + o_scale : Optional[float] = 1.0 + output scale factor for xqa fp8 output. + Returns ------- out : Union[torch.Tensor, FP4Tensor] @@ -2196,6 +2200,7 @@ def trtllm_batch_decode_with_kv_cache( kv_layout=kv_layout, enable_pdl=enable_pdl, q_len_per_req=q_len_per_req, + o_scale=o_scale, ) elif backend == "trtllm-gen": # Convert NHD layout to HND if necessary (transpose only changes stride, not data) @@ -2340,6 +2345,7 @@ def xqa_batch_decode_with_kv_cache( kv_layout: str = "NHD", enable_pdl: bool = None, q_len_per_req: Optional[int] = 1, + o_scale: Optional[float] = 1.0, ) -> torch.Tensor: """ Parameters @@ -2388,6 +2394,9 @@ def xqa_batch_decode_with_kv_cache( Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode. + o_scale : Optional[float] = 1.0 + output scale factor for fp8 output. 
+ Returns ------- out : torch.Tensor @@ -2434,7 +2443,7 @@ def xqa_batch_decode_with_kv_cache( workspace_u8 = workspace_buffer.view(torch.uint8) semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore scratch = workspace_u8[8 * 1024 * 1024 :] - kv_scale_value = bmm2_scale + kv_scale_value = bmm2_scale * o_scale q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5) query_new = query.unsqueeze(1) @@ -2464,6 +2473,7 @@ def xqa_batch_decode_with_kv_cache( kv_layout=kv_layout, sm_count=sm_count, enable_pdl=enable_pdl, + rcp_out_scale=1.0 / o_scale, ) return out diff --git a/flashinfer/jit/xqa.py b/flashinfer/jit/xqa.py index 5768236c73..04ab098be2 100644 --- a/flashinfer/jit/xqa.py +++ b/flashinfer/jit/xqa.py @@ -28,7 +28,6 @@ "-DBEAM_WIDTH=1", "-DUSE_INPUT_KV=0", "-DUSE_CUSTOM_BARRIER=1", - "-DLOW_PREC_OUTPUT=0", "-DSPEC_DEC=0", ] @@ -40,6 +39,7 @@ def gen_xqa_module( head_dim: int, head_group_ratio: int, use_sliding_window: bool, + output_dtype: torch.dtype, ) -> JitSpec: if input_dtype == torch.float16: flag_input_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"] @@ -76,6 +76,11 @@ def gen_xqa_module( else: flag_sliding_window = ["-DSLIDING_WINDOW=0"] + if output_dtype == torch.float8_e4m3fn: + flag_low_prec_output = ["-DLOW_PREC_OUTPUT=1"] + else: + flag_low_prec_output = ["-DLOW_PREC_OUTPUT=0"] + compilation_context = CompilationContext() nvcc_flags = compilation_context.get_nvcc_flags_list( supported_major_versions=[9, 10, 11, 12] @@ -85,7 +90,7 @@ def gen_xqa_module( flag_mla_wrapper = ["-DMLA_WRAPPER=0"] return gen_jit_spec( - f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}", + f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}", [ jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu", jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu", @@ -101,6 +106,7 @@ def gen_xqa_module( + flag_kv_cache_dtype + flag_head_group_ratio + flag_sliding_window + + flag_low_prec_output + flag_mla_wrapper, extra_ldflags=["-lcuda"], # Add CUDA Driver API library ) diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index fd75e34f87..dbf80e7b11 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -38,6 +38,7 @@ def get_xqa_module( head_dim: int, head_group_ratio: int, use_sliding_window: bool, + output_dtype: torch.dtype, ): module = gen_xqa_module( input_dtype, @@ -46,10 +47,11 @@ def get_xqa_module( head_dim, head_group_ratio, use_sliding_window, + output_dtype, ).build_and_load() @register_custom_op( - f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}", + f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}", mutates_args=("output", "workspace_buffer"), ) def xqa( @@ -59,6 +61,7 @@ def xqa( sliding_win_size: int, q_scale: float, output: torch.Tensor, + rcp_out_scale: float, q: torch.Tensor, sinks: 
Optional[torch.Tensor], k_cache: torch.Tensor, @@ -79,6 +82,7 @@ def xqa( sliding_win_size, q_scale, output, + rcp_out_scale, q, sinks, k_cache, @@ -94,7 +98,7 @@ def xqa( ) @register_fake_op( - f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}" + f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}" ) def _fake_xqa( run_sm90_fp8_mha: bool, @@ -103,6 +107,7 @@ def _fake_xqa( sliding_win_size: int, q_scale: float, output: torch.Tensor, + rcp_out_scale: float, q: torch.Tensor, sinks: Optional[torch.Tensor], k_cache: torch.Tensor, @@ -140,6 +145,7 @@ def xqa( kv_layout: str = "NHD", sm_count: Optional[int] = None, enable_pdl: Optional[bool] = None, + rcp_out_scale: float = 1.0, ) -> None: r"""Apply attention with paged KV cache using XQA kernel. Parameters @@ -167,7 +173,7 @@ def xqa( Data type should be torch.uint32. output : torch.Tensor Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``. - Data type should match query tensor. This tensor will be modified in-place. + Data type should match query tensor or kv tensor. This tensor will be modified in-place. workspace_buffer : torch.Tensor Workspace buffer for temporary computations. Data type should be torch.uint8. @@ -196,6 +202,8 @@ def xqa( enable_pdl : Optional[bool], default=None Whether to enable PDL (Persistent Data Loader) optimization. If None, will be set to True if hardware supports it. + rcp_out_scale : float, default=1.0 + Reciprocal of output scale factor. 
Note ---- @@ -231,6 +239,13 @@ def xqa( assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype" + if output.dtype == torch.float8_e4m3fn: + assert k_cache.dtype == torch.float8_e4m3fn, ( + "KV cache must be fp8 when output is fp8" + ) + else: + assert output.dtype == q.dtype, "Output and query must have the same dtype" + # Convert HND layout to NHD if necessary (transpose only changes stride, not data) if kv_layout == "HND": # For HND: [..., H, N, D] -> NHD: [..., N, H, D] @@ -255,6 +270,7 @@ def xqa( head_dim, head_group_ratio, use_sliding_window, + output.dtype, ) xqa_module.xqa( run_sm90_fp8_mha, @@ -263,6 +279,7 @@ def xqa( sliding_win_size if use_sliding_window else 0, q_scale, output, + rcp_out_scale, q, sinks, k_cache, diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index d279144195..0d80e9cf90 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -665,6 +665,9 @@ def _test_trtllm_batch_decode( # todo(Yingyi): add support for nvfp4 with speculative decoding pytest.skip("nvfp4 is not supported for q_len_per_req > 1") + if backend == "trtllm-gen" and o_dtype == "fp8" and q_dtype != "fp8": + pytest.skip("trtllm-gen backend only supports fp8 output for fp8 query") + # Set up test parameters torch.manual_seed(0) @@ -797,6 +800,7 @@ def _test_trtllm_batch_decode( enable_pdl=enable_pdl, backend=backend, q_len_per_req=q_len_per_req, + o_scale=o_scale, ) if backend == "trtllm-gen": # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero @@ -926,6 +930,8 @@ def _test_trtllm_batch_decode( ("fp16", "fp16", "fp16"), ("bf16", "fp8", "bf16"), ("fp16", "fp8", "fp16"), + ("bf16", "fp8", "fp8"), + ("fp16", "fp8", "fp8"), ("fp8", "fp8", "bf16"), ("fp8", "fp8", "fp16"), ("fp8", "fp8", "fp8"), diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index 172135c571..b6454de05a 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -137,7 +137,6 @@ def ref_attention( @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("use_sliding_window", [True, False]) @pytest.mark.parametrize("input_type", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("fp8_kv_cache", [True, False]) @pytest.mark.parametrize("use_attention_sinks", [True, False]) @pytest.mark.parametrize("seq_len", [2, 15, 256, 514]) @pytest.mark.parametrize("batch_size", [1, 4]) @@ -146,8 +145,17 @@ def ref_attention( @pytest.mark.parametrize("valid_elems_per_head", [32, 128]) @pytest.mark.parametrize("head_grp_size", [8, 16]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) -@pytest.mark.parametrize("kv_scale", [1.0, 0.5]) @pytest.mark.parametrize("q_scale", [1.0, 0.5]) +@pytest.mark.parametrize( + "fp8_kv_cache,kv_scale,use_fp8_output", + [ + (False, 1.0, False), # Non-FP8 KV cache: kv_scale=1.0, no FP8 output + (True, 1.0, False), # FP8 KV cache: kv_scale=1.0, no FP8 output + (True, 1.0, True), # FP8 KV cache: kv_scale=1.0, with FP8 output + (True, 0.5, False), # FP8 KV cache: kv_scale=0.5, no FP8 output + (True, 0.5, True), # FP8 KV cache: kv_scale=0.5, with FP8 output + ], +) def test_xqa( batch_size, nb_k_heads, @@ -163,9 +171,8 @@ def test_xqa( kv_layout, kv_scale, q_scale, + use_fp8_output, ): - if kv_scale != 1.0 and fp8_kv_cache is False: - pytest.skip("kv cache scale works only for fp8 kv cache") set_random_seed(42) nb_q_heads = nb_k_heads * head_grp_size @@ -175,7 +182,7 @@ def test_xqa( beam_width, 
nb_q_heads, valid_elems_per_head, - dtype=input_type, + dtype=torch.float8_e4m3fn if use_fp8_output else input_type, device="cuda", ) output.fill_(float("nan")) @@ -326,6 +333,8 @@ def test_xqa( scratch_size = 256 << 20 scratch_buf = torch.zeros(scratch_size, dtype=torch.uint8, device="cuda") + rcp_out_scale = 4.0 if use_fp8_output else 1.0 + xqa( q_heads, cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads, @@ -344,6 +353,7 @@ def test_xqa( kv_layout=kv_layout, sm_count=sm_count, enable_pdl=enable_pdl, + rcp_out_scale=rcp_out_scale, ) for req in range(batch_size): @@ -398,6 +408,10 @@ def test_xqa( else: atol = 0.01 rtol = 0.01 + if use_fp8_output: + ref_output = ref_output * rcp_out_scale + atol = 0.15 + rtol = 0.15 diff_abs = torch.abs(ref_output - kernel_output) diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8) diff --git a/tests/attention/test_xqa_batch_decode.py b/tests/attention/test_xqa_batch_decode.py index 7a2bd3356a..a360545041 100644 --- a/tests/attention/test_xqa_batch_decode.py +++ b/tests/attention/test_xqa_batch_decode.py @@ -311,6 +311,8 @@ def get_last_page_len(seq_lens, page_size): ("fp16", "fp16", "fp16"), ("bf16", "fp8", "bf16"), ("fp16", "fp8", "fp16"), + ("bf16", "fp8", "fp8"), + ("fp16", "fp8", "fp8"), ], ) @pytest.mark.parametrize("enable_pdl", [True, False, None]) @@ -458,12 +460,13 @@ def test_xqa_batch_decode( sinks=(sink if enable_sink else None), kv_layout=kv_layout, q_len_per_req=q_len_per_req, + o_scale=o_scale, ) # Verification torch.testing.assert_close( - output, - output_ref, + output.float(), + output_ref.float() / o_scale, rtol=1e-1 if kv_dtype == "fp8" else 1e-2, atol=1e-1 if kv_dtype == "fp8" else 1e-2, ) From abf6a14e836fd26d25211fb2a98f2b6d9eaaf3c5 Mon Sep 17 00:00:00 2001 From: Raayan Dhar <58057652+raayandhar@users.noreply.github.com> Date: Wed, 12 Nov 2025 06:25:58 -0800 Subject: [PATCH 052/130] chore: update requires-python in pyproject.toml (#2080) --- benchmarks/bench_batch_attention.py | 2 +- benchmarks/bench_mixed_attention.py | 2 +- docs/installation.rst | 2 +- flashinfer/autotuner.py | 4 +++- flashinfer/cascade.py | 4 +++- flashinfer/comm/mnnvl.py | 2 +- flashinfer/cute_dsl/blockscaled_gemm.py | 2 ++ flashinfer/jit/attention/utils.py | 17 +++++++++++++---- pyproject.toml | 2 +- 9 files changed, 26 insertions(+), 11 deletions(-) diff --git a/benchmarks/bench_batch_attention.py b/benchmarks/bench_batch_attention.py index 2c1071d808..c94a86eacc 100644 --- a/benchmarks/bench_batch_attention.py +++ b/benchmarks/bench_batch_attention.py @@ -436,7 +436,7 @@ def main(args: argparse.Namespace) -> None: records_new = [] records_separate = [] for cfg_id, (decode_case, prefill_case) in enumerate( - zip(decode_lens, prefill_lens), start=1 + zip(decode_lens, prefill_lens, strict=True), start=1 ): prefill_kv_lens = [p[0] for p in prefill_case] prefill_qo_lens = [p[1] for p in prefill_case] diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index 9773e8f37d..9bb6616737 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -218,7 +218,7 @@ def _run_single_prefill(): head_dim = 128 for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate( - zip(p_q_configs, p_kv_configs, d_q_len_configs, d_kv_len_configs) + zip(p_q_configs, p_kv_configs, d_q_len_configs, d_kv_len_configs, strict=True) ): print(f"===== Benchmark {idx + 1}: (kv_len, qo_len) set =====") run_bench( diff --git a/docs/installation.rst b/docs/installation.rst index 7550a73622..4f628f7094 
100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -13,7 +13,7 @@ Prerequisites - OS: Linux only -- Python: 3.9, 3.10, 3.11, 3.12, 3.13 +- Python: 3.10, 3.11, 3.12, 3.13, 3.14 Quick Start ^^^^^^^^^^^ diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index a82fabd8c0..6b6a0c3e48 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -648,7 +648,9 @@ def _generate_optimization_profiles( opt_shapes_max = { v1: v2 - for v1, v2 in zip(opt_shapes, tuple(opt_shapes[1:]) + (float("inf"),)) + for v1, v2 in zip( + opt_shapes, tuple(opt_shapes[1:]) + (float("inf"),), strict=True + ) } dynamic_dims.append( (spec.input_idx, spec.dim_idx, opt_shapes_max, opt_shapes) diff --git a/flashinfer/cascade.py b/flashinfer/cascade.py index 5281672b39..267f0d2990 100644 --- a/flashinfer/cascade.py +++ b/flashinfer/cascade.py @@ -349,6 +349,7 @@ def __init__( paged_kv_indptr_buf_arr, paged_kv_indices_buf_arr, paged_kv_last_page_len_buf_arr, + strict=True, ) ] else: @@ -381,7 +382,7 @@ def reset_workspace_buffer( be the same as the device of the input tensors. """ for wrapper, int_workspace_buffer in zip( - self._batch_prefill_wrappers, int_workspace_buffers + self._batch_prefill_wrappers, int_workspace_buffers, strict=True ): wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer) @@ -479,6 +480,7 @@ def plan( paged_kv_indptr_arr, paged_kv_indices_arr, paged_kv_last_page_len, + strict=True, ) ): wrapper.plan( diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 2e98efd1e2..12aec978ec 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -414,7 +414,7 @@ def open_mnnvl_memory(mapping: Mapping, size: int): pidfds.append(pidfd) remote_fds = [] - for pidfd, fd in zip(pidfds, all_handles_data): + for pidfd, fd in zip(pidfds, all_handles_data, strict=True): remote_fd = syscall(SYS_pidfd_getfd, pidfd, fd, 0) if remote_fd < 0: err = ctypes.get_errno() diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py index d69eda2743..7c4ecc1fc7 100644 --- a/flashinfer/cute_dsl/blockscaled_gemm.py +++ b/flashinfer/cute_dsl/blockscaled_gemm.py @@ -154,6 +154,7 @@ def __new_from_mlir_values__(self, values): self._cluster_shape_mnk, ], self._values_pos, + strict=True, ): obj_list.append(new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] @@ -348,6 +349,7 @@ def _get_current_work_for_linear_idx( cur_cluster_coord, self.cta_id_in_cluster, (*self.params.cluster_shape_mn, Int32(1)), + strict=True, ) ) diff --git a/flashinfer/jit/attention/utils.py b/flashinfer/jit/attention/utils.py index ac033a65b8..2f2e030b12 100644 --- a/flashinfer/jit/attention/utils.py +++ b/flashinfer/jit/attention/utils.py @@ -30,11 +30,14 @@ def generate_additional_params( for dtype, var in zip( additional_tensor_dtypes, additional_tensor_names, + strict=True, ) ] + [ f"{dtype} {var};\n" - for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names) + for dtype, var in zip( + additional_scalar_dtypes, additional_scalar_names, strict=True + ) ] ) additional_func_params = "".join( @@ -48,7 +51,9 @@ def generate_additional_params( ] + [ f", {dtype} {var}" - for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names) + for dtype, var in zip( + additional_scalar_dtypes, additional_scalar_names, strict=True + ) ] ) if is_sm90_template: @@ -59,7 +64,9 @@ def generate_additional_params( if var.startswith("maybe") else f"params.additional_params.{var} = static_cast<{dtype}*>({var}.data_ptr());" 
) - for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) + for dtype, var in zip( + additional_tensor_dtypes, additional_tensor_names, strict=True + ) ] + [ f"params.additional_params.{var} = {var};" @@ -74,7 +81,9 @@ def generate_additional_params( if var.startswith("maybe") else f"params.{var} = static_cast<{dtype}*>({var}.data_ptr());" ) - for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) + for dtype, var in zip( + additional_tensor_dtypes, additional_tensor_names, strict=True + ) ] + [f"params.{var} = {var};" for var in additional_scalar_names] ) diff --git a/pyproject.toml b/pyproject.toml index 57a966c04d..eafa93ba0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ [project] name = "flashinfer-python" description = "FlashInfer: Kernel Library for LLM Serving" -requires-python = ">=3.9,<4.0" +requires-python = ">=3.10,<4.0" authors = [{ name = "FlashInfer team" }] license = "Apache-2.0" readme = "README.md" From 6765cadd14fbedc9ffab428a87149a7d3f5d69f1 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 12 Nov 2025 06:26:36 -0800 Subject: [PATCH 053/130] [Test] Optimize test_trtllm_gen_fused_moe.py (#2072) --- tests/moe/test_trtllm_gen_fused_moe.py | 68 +++++++++++++++++--------- 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 4706b4c87a..747946fc09 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -49,6 +49,10 @@ from flashinfer.utils import get_compute_capability +# Max num tokens to tune for trtllm-gen fused moe +TUNE_MAX_NUM_TOKENS = 4096 + + def check_cuda(err): """Unified CUDA error checking function used throughout the file.""" if err != runtime.cudaError_t.cudaSuccess: @@ -76,6 +80,7 @@ def __init__(self, moe_impl, static_data, **config): self.moe_impl = moe_impl self.static_data = static_data self.config = config + self.enable_autotune = config.get("enable_autotune", True) self.graph = None self.graph_exec = None self.stream = None @@ -106,7 +111,7 @@ def capture(self, hidden_states_sample, **runtime_args): self.input_tensor = hidden_states_sample.clone() # Warmup - with torch.cuda.stream(torch_stream), autotune(True): + with torch.cuda.stream(torch_stream), autotune(self.enable_autotune): for _ in range(1): self._run_moe_computation(runtime_args) @@ -207,6 +212,7 @@ def _run_moe_computation(self, runtime_args): routing_method_type=self.config["routing_method_type"], gated_act_type=self.config["gated_act_type"], do_finalize=True, + tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) return output # Extract tensor from tuple @@ -551,6 +557,7 @@ def call_moe( routed_scaling = kwargs["routed_scaling"] gated_act_type = kwargs["gated_act_type"] routing_method_type = kwargs["routing_method_type"] + enable_autotune = kwargs.get("enable_autotune", True) # Create CUDA graph configuration config = { @@ -563,6 +570,7 @@ def call_moe( "routed_scaling": routed_scaling, "gated_act_type": gated_act_type, "routing_method_type": routing_method_type, + "enable_autotune": enable_autotune, } runtime_args = { @@ -761,6 +769,7 @@ def call_moe( intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] + enable_autotune = kwargs.get("enable_autotune", True) enable_pdl = kwargs.get("enable_pdl") hidden_states_scale = kwargs["hidden_states_scale"] hidden_states_quant = 
kwargs["hidden_states_quant"] @@ -772,7 +781,7 @@ def call_moe( ) # Use autotuner for optimal kernel selection - with autotune(True): + with autotune(enable_autotune): output = trtllm_fp8_block_scale_moe( expert_logits, routing_bias, @@ -795,6 +804,7 @@ def call_moe( use_shuffled_weight=static_data["use_shuffled_weight"], weight_layout=static_data["weight_layout"], enable_pdl=enable_pdl, + tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) return output.to(torch.float) @@ -937,6 +947,7 @@ def call_moe( intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] + enable_autotune = kwargs.get("enable_autotune", True) # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_fp8, _ = quant_fp8_per_tensor( @@ -944,7 +955,7 @@ def call_moe( ) # Use autotuner for optimal kernel selection - with autotune(True): + with autotune(enable_autotune): output = trtllm_fp8_per_tensor_scale_moe( ( expert_logits.to(torch.bfloat16) @@ -970,6 +981,7 @@ def call_moe( == RoutingMethodType.Llama4, # Use_routing_scales_on_input None, routing_method_type, + tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) return output.to(torch.float) @@ -1101,9 +1113,10 @@ def call_moe( top_k_groups = kwargs["top_k_groups"] intermediate_size = kwargs["intermediate_size"] routing_method_type = kwargs["routing_method_type"] + enable_autotune = kwargs.get("enable_autotune", True) # Use autotuner for optimal kernel selection - with autotune(True): + with autotune(enable_autotune): output = trtllm_bf16_moe( expert_logits, # float routing_bias, @@ -1120,6 +1133,7 @@ def call_moe( use_shuffled_weight=static_data["use_shuffled_weight"], weight_layout=static_data["weight_layout"], routing_method_type=routing_method_type, + tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) return output.to(torch.float) @@ -1408,20 +1422,18 @@ def routing_reference_topk(expert_logits, top_k, num_experts, padding): def check_accuracy(a, b, atol, rtol, percent): """Unified accuracy checking function with detailed error reporting.""" - if torch.any(torch.isnan(a)): - raise Exception("NaN in reference output") - if torch.any(torch.isnan(b)): - raise Exception("NaN in actual output") - if torch.any(torch.isinf(a)): - raise Exception("Inf in reference output") - if torch.any(torch.isinf(b)): - raise Exception("Inf in actual output") + if not torch.isfinite(a).all(): + raise Exception("Non-finite values in reference output") + if not torch.isfinite(b).all(): + raise Exception("Non-finite values in actual output") assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() + close = torch.isclose(a, b, atol=atol, rtol=rtol) + match_ratio = close.float().mean() + if match_ratio >= percent: + return + + mismatch_percent = 1.0 - match_ratio.item() if mismatch_percent > 1 - percent: raise Exception( f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " @@ -1999,6 +2011,7 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "gated_act_type": args.gated_act_type, "hidden_states_scale": args.hidden_states_scale, "hidden_states_quant": kwargs["hidden_states_quant"], + "enable_autotune": kwargs.get("enable_autotune", True), } return moe_impl.call_moe( @@ -2238,6 +2251,8 @@ def run_moe_test( pytest.fail("Reference computation failed to produce output") # Compute actual output + enable_autotune = 
routing_config.get("enable_autotune", True) + output_dequant_actual = moe_impl.compute_production( args_dequant, args, @@ -2253,6 +2268,7 @@ def run_moe_test( weight_processing=weight_processing, enable_pdl=True, hidden_states_quant=inputs_data["hidden_states"], + enable_autotune=enable_autotune, ) # Compare outputs @@ -2267,7 +2283,7 @@ def run_moe_test( # Test: Renormalize routing -@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 3072]) +@pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384]) @pytest.mark.parametrize( @@ -2301,8 +2317,9 @@ def run_moe_test( BF16Moe, ], "compatible_intermediate_size": [384, 768, 1024], + "enable_autotune": True, }, - id="Qwen3", + id="Qwen3_MOE", ), pytest.param( { @@ -2321,6 +2338,7 @@ def run_moe_test( BF16Moe, ], "compatible_intermediate_size": [384, 1024], + "enable_autotune": False, }, id="Renorm", ), @@ -2341,6 +2359,7 @@ def run_moe_test( BF16Moe, ], "compatible_intermediate_size": [512], + "enable_autotune": True, }, id="Qwen3_next", ), @@ -2406,7 +2425,7 @@ def test_renormalize_routing( # Test: DeepSeekV3 routing -@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 3072]) +@pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( @@ -2433,6 +2452,7 @@ def test_renormalize_routing( "routing_method_type": RoutingMethodType.DeepSeekV3, "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], "compatible_intermediate_size": [1024, 2048], + "enable_autotune": True, }, id="kimi_k2", ), @@ -2448,6 +2468,7 @@ def test_renormalize_routing( "routing_method_type": RoutingMethodType.DeepSeekV3, "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], "compatible_intermediate_size": [512, 1024, 2048], + "enable_autotune": True, }, id="DSv3", ), @@ -2463,6 +2484,7 @@ def test_renormalize_routing( "routing_method_type": RoutingMethodType.DeepSeekV3, "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], "compatible_intermediate_size": [384, 768], + "enable_autotune": False, }, id="DSLite", ), @@ -2528,7 +2550,7 @@ def test_deepseekv3_routing( # Test: TopK routing -@pytest.mark.parametrize("num_tokens", [1, 8, 128]) # Limited for GeGlu +@pytest.mark.parametrize("num_tokens", [8, 128]) # Limited for GeGlu @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [384, 512, 768, 1024]) @pytest.mark.parametrize( @@ -2552,7 +2574,8 @@ def test_deepseekv3_routing( "has_routing_bias": False, "routing_method_type": RoutingMethodType.TopK, "compatible_moe_impls": [FP4Moe], - "compatible_intermediate_size": [384, 512, 768, 1024], + "compatible_intermediate_size": [512, 768, 1024], + "enable_autotune": True, }, id="TopK", ), @@ -2602,7 +2625,7 @@ def test_topk_routing( # Test: Llama4 routing -@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [1024, 2048]) @pytest.mark.parametrize( @@ -2626,6 +2649,7 @@ def test_topk_routing( "routing_method_type": RoutingMethodType.Llama4, "compatible_moe_impls": [FP8PerTensorMoe], "compatible_intermediate_size": [1024, 2048], + "enable_autotune": True, }, id="Llama4", ), From b433fc729ecda5c010a807fc50ecd4f1a6ee6ad6 Mon Sep 17 00:00:00 2001 From: "Brian K. 
Ryu" Date: Wed, 12 Nov 2025 20:17:10 -0800 Subject: [PATCH 054/130] test: Change incorrect inputs in test_hopper.py (#2083) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Brings in some changes to `test_hopper.py` to pass more unit tests * `test_deepseek_prefill` --> Raise tolerance for bf16 inputs * Others: The ``` token_pos_in_items_len=torch.tensor(token_pos_in_items_len) .to(dtype=torch.uint32) .to(0), ``` is an incorrect API and results in invalid input errors. Change it to: `token_pos_in_items_len=token_pos_in_items_len,` so that it matches the correct usage in e.g. [test_batch_prefill_kernels.py](https://github.com/flashinfer-ai/flashinfer/blob/6765cadd14fbedc9ffab428a87149a7d3f5d69f1/tests/attention/test_batch_prefill_kernels.py#L890) After this, `test_hopper.py` result improves to `3 failed, 2865 passed, 1320 skipped in 65.26s (0:01:05) ` ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes --- tests/attention/test_hopper.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/tests/attention/test_hopper.py b/tests/attention/test_hopper.py index f928e213bb..0a1b6fe8a7 100644 --- a/tests/attention/test_hopper.py +++ b/tests/attention/test_hopper.py @@ -194,8 +194,15 @@ def test_deepseek_prefill( ) o_sm90, lse_sm90 = wrapper_sm90.run_return_lse(q, k, v) - torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + if dtype == torch.half: + rtol = 1e-3 + atol = 1e-3 + else: # bfloat16 + rtol = 1e-2 + atol = 1e-2 + + torch.testing.assert_close(lse_sm80, lse_sm90, rtol=rtol, atol=atol) + torch.testing.assert_close(o_sm80, o_sm90, rtol=rtol, atol=atol) @pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) @@ -373,9 +380,7 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3( token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) .to(dtype=torch.uint16) .to(0), - token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) - .to(0), + token_pos_in_items_len=token_pos_in_items_len, max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) o_fa2, lse_fa2 = wrapper_fa2.run_return_lse(q, kv_data) @@ -398,9 +403,7 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3( token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) .to(dtype=torch.uint16) .to(0), - token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) - .to(0), + token_pos_in_items_len=token_pos_in_items_len, max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) @@ -507,9 +510,7 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3_bsz2( token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) .to(dtype=torch.uint16) .to(0), - 
token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) - .to(0), + token_pos_in_items_len=token_pos_in_items_len, max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) o_fa2, lse_fa2 = wrapper_fa2.run_return_lse(q, kv_data) @@ -532,9 +533,7 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3_bsz2( token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) .to(dtype=torch.uint16) .to(0), - token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) - .to(0), + token_pos_in_items_len=token_pos_in_items_len, max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) From 54101e9533bd65af6393e7f86c30c0bff9794499 Mon Sep 17 00:00:00 2001 From: Johnny Date: Thu, 13 Nov 2025 10:54:43 +0100 Subject: [PATCH 055/130] [NVIDIA] Thor & Spark Support (#2028) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Thor and Spark support when wheels are generating ## ๐Ÿ” Related Issues Output says that is not compatible. Only with JIT is working. ## Summary by CodeRabbit * **New Features** * Broadened GPU architecture support to include additional newer architectures. * **Documentation** * Updated README and installation docs to show the revised CUDA architecture example list. * **Chores** * Adjusted release/nightly workflows and build scripts to select architectures using an expanded CUDA-version threshold and branching logic. * **Performance** * Extended architecture-specific build/runtime handling to cover an additional GPU architecture affecting memory-related behavior. --------- Co-authored-by: Zihao Ye Co-authored-by: yzh119 --- .github/workflows/nightly-release.yml | 2 +- .github/workflows/release.yml | 2 +- README.md | 2 +- csrc/xqa/mha.cu | 2 +- csrc/xqa/utils.cuh | 3 ++- docs/installation.rst | 2 +- scripts/task_test_jit_cache_package_build_import.sh | 11 ++++++++++- 7 files changed, 17 insertions(+), 7 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 4d5acdfe63..2e7230cfa5 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -145,7 +145,7 @@ jobs: - name: Build wheel in container env: DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }} - FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }} + FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0f 12.0f' }} FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} run: | # Extract CUDA major and minor versions diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7e406ff2ac..0c95611c50 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -182,7 +182,7 @@ jobs: - name: Build wheel in container env: DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }} - FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }} + FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0f 12.0f' }} 
run: | # Extract CUDA major and minor versions CUDA_MAJOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f1) diff --git a/README.md b/README.md index 8f93c97f7a..88b579b180 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ python -m pip install dist/*.whl `flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs): ```bash -export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" +export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0f 12.0f" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index 9359eb5d12..016a4f982a 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -93,7 +93,7 @@ __constant__ constexpr uint32_t cacheVTileSeqLen = 32; constexpr uint32_t preferedKHeadPartBytes = 64; __constant__ constexpr uint32_t cacheVTileSeqLen = 32; #elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 || \ - __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030 + __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030 || __CUDA_ARCH__ == 1100 constexpr uint32_t preferedKHeadPartBytes = 128; __constant__ constexpr uint32_t cacheVTileSeqLen = 64; #else diff --git a/csrc/xqa/utils.cuh b/csrc/xqa/utils.cuh index f96d83f5f5..6302d4e20b 100644 --- a/csrc/xqa/utils.cuh +++ b/csrc/xqa/utils.cuh @@ -46,7 +46,8 @@ __constant__ constexpr float kE4M3_MAX = 448.F; constexpr uint32_t kMAX_SMEM_SIZE = (99u << 10); #elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 constexpr uint32_t kMAX_SMEM_SIZE = (163u << 10); -#elif __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030 +#elif __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030 || \ + __CUDA_ARCH__ == 1100 constexpr uint32_t kMAX_SMEM_SIZE = (227u << 10); #endif #endif diff --git a/docs/installation.rst b/docs/installation.rst index 4f628f7094..9087e87471 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -92,7 +92,7 @@ You can follow the steps below to install FlashInfer from source code: .. code-block:: bash - export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a" + export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0f 12.0f" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl diff --git a/scripts/task_test_jit_cache_package_build_import.sh b/scripts/task_test_jit_cache_package_build_import.sh index e2e4a824aa..d03937bc47 100755 --- a/scripts/task_test_jit_cache_package_build_import.sh +++ b/scripts/task_test_jit_cache_package_build_import.sh @@ -43,7 +43,16 @@ arches = ["7.5", "8.0", "8.9", "9.0a"] if cuda_ver is not None: try: major, minor = map(int, cuda_ver.split(".")[:2]) - if (major, minor) >= (12, 8): + if (major, minor) >= (13, 0): + arches.append("10.0a") + arches.append("10.3a") + arches.append("11.0f") + arches.append("12.0f") + elif (major, minor) >= (12, 9): + arches.append("10.0a") + arches.append("10.3a") + arches.append("12.0f") + elif (major, minor) >= (12, 8): arches.append("10.0a") arches.append("12.0a") except Exception: From 9a79b7868cce0bf521499c3a04a197bb6573468c Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Thu, 13 Nov 2025 17:35:16 -0800 Subject: [PATCH 056/130] [API change] deprecate tile_token_dim in trtllm_moe (#2086) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Deprecate `tile_token_dim` in trtllm_moe. 
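A minimal sketch of the kind of guard this implies (hypothetical helper shown only for illustration; the actual code in `flashinfer/fused_moe/core.py` may differ):

```python
import warnings

def _warn_tile_tokens_dim(tile_tokens_dim=None):
    # Hypothetical helper: the keyword is still accepted but ignored,
    # and passing it emits a deprecation warning.
    if tile_tokens_dim is not None:
        warnings.warn(
            "tile_tokens_dim is deprecated and ignored; it will be removed "
            "in the next major release.",
            DeprecationWarning,
            stacklevel=2,
        )
```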
It is already not used and mark with deprecation warning, plan to deprecate totally in next major release ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Refactor** * Removed the deprecated `tile_tokens_dim` parameter from MOE benchmarks and kernel functions, streamlining API calls and eliminating associated deprecation warnings. Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- benchmarks/README.md | 3 +- .../bench_trtllm_gen_fused_moe_autotuner.py | 3 -- .../routines/flashinfer_benchmark_utils.py | 1 - benchmarks/routines/moe.py | 37 ------------------- benchmarks/samples/sample_testlist_output.csv | 2 +- benchmarks/samples/sample_testlist_output.txt | 12 +++--- csrc/trtllm_fused_moe_kernel_launcher.cu | 4 -- flashinfer/fused_moe/core.py | 32 ---------------- tests/moe/test_trtllm_gen_fused_moe.py | 3 -- tests/moe/test_trtllm_gen_routed_fused_moe.py | 2 - 10 files changed, 8 insertions(+), 91 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index e7e17156a4..d81e9c3642 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -166,8 +166,7 @@ The output CSV will contain detailed metrics including: | `--topk_group` | Number of groups to consider for top-k routing. Default: 1 | | `--routed_scaling_factor`| Scaling factor for routing. Default: 2.5 | | `--local_expert_offset` | Offset of local experts in global expert space. Default: 0 | -| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts | -| `--tile_tokens_dim` | Tile dimension for tokens. Default: 8 | +| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts | | | `--routing_method` | Routing method: `renormalize`, `deepseek_v3`, `llama4`, `renormalize_naive`. Default: `deepseek_v3`. | | `--use_shuffled_weight` | Whether to use shuffled weight layout | | `--weight_layout` | Weight layout: 0=MajorK, 1=MajorMn, 2=BlockMajorK. 
Default: 0 | diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 0aff25860e..203faaff82 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -114,7 +114,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( 0, # local_expert_offset num_experts, 2.5, # routed_scaling_factor - None, # tile_tokens_dim RoutingMethodType.DeepSeekV3.value, True, # use_shuffled_weight WeightLayout.BlockMajorK.value, # weight_layout @@ -142,7 +141,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( num_experts, 1.0, # routed_scaling_factor False, # use_routing_scales_on_input - None, # tile_tokens_dim RoutingMethodType.TopK.value, enable_pdl, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, @@ -287,7 +285,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( 0, # local_expert_offset num_experts, None, # routed_scaling_factor - None, # tile_tokens_dim RoutingMethodType.Renormalize.value, True, enable_pdl, diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 8798f8340f..520029f0ec 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -53,7 +53,6 @@ "routed_scaling_factor", "local_expert_offset", "local_num_experts", - "tile_tokens_dim", "routing_method", "use_shuffled_weight", "weight_layout", diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index 6af3425c73..8f26bdb8f7 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -116,13 +116,6 @@ def parse_moe_args(line, parser): default=None, help="Number of experts handled by this device. Defaults to num_experts.", ) - parser.add_argument( - "--tile_tokens_dim", - type=int, - required=False, - default=8, - help="Tile dimension for tokens.", - ) parser.add_argument( "--routing_method", type=str, @@ -560,7 +553,6 @@ def testTrtllmFp4BlockScaleMoe(args): ) local_expert_offset = args.local_expert_offset local_num_experts = args.local_num_experts or num_experts - tile_tokens_dim = args.tile_tokens_dim routing_method_type = args.routing_method_type use_shuffled_weight = args.use_shuffled_weight weight_layout = args.weight_layout @@ -705,7 +697,6 @@ def run_fp4_moe(): local_expert_offset=local_expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, - tile_tokens_dim=tile_tokens_dim, routing_method_type=routing_method_type, gated_act_type=gated_act_type, do_finalize=True, @@ -780,7 +771,6 @@ def run_fp4_moe(): cur_res["routed_scaling_factor"] = routed_scaling_factor cur_res["local_expert_offset"] = local_expert_offset cur_res["local_num_experts"] = local_num_experts - cur_res["tile_tokens_dim"] = tile_tokens_dim cur_res["routing_method"] = args.routing_method cur_res["use_shuffled_weight"] = use_shuffled_weight cur_res["weight_layout"] = weight_layout @@ -1185,7 +1175,6 @@ def testTrtllmFp8BlockScaleMoe(args): ) local_expert_offset = args.local_expert_offset local_num_experts = args.local_num_experts or num_experts - tile_tokens_dim = args.tile_tokens_dim routing_method_type = args.routing_method_type use_shuffled_weight = args.use_shuffled_weight weight_layout = args.weight_layout @@ -1277,27 +1266,6 @@ def testTrtllmFp8BlockScaleMoe(args): print(f"[VVERBOSE] gemm1_weights_fp8.shape = {gemm1_weights_fp8.shape}") print(f"[VVERBOSE] gemm2_weights_fp8.shape = {gemm2_weights_fp8.shape}") - # Match test heuristic 
for tile_tokens_dim when using BlockMajorK - if use_shuffled_weight and weight_layout == WeightLayout.BlockMajorK: - - def _next_pow2(x: int) -> int: - x = max(1, x) - x -= 1 - x |= x >> 1 - x |= x >> 2 - x |= x >> 4 - x |= x >> 8 - x |= x >> 16 - return x + 1 - - tokens_per_expert = max(1, (num_tokens * top_k) // max(local_num_experts, 1)) - suggested_tile = min(max(_next_pow2(tokens_per_expert), 8), 64) - if suggested_tile != tile_tokens_dim and args.verbose >= 1: - print( - f"[INFO] Overriding tile_tokens_dim {tile_tokens_dim} -> {suggested_tile} for BlockMajorK" - ) - tile_tokens_dim = suggested_tile - def run_fp8_block_moe(): # Quantize hidden states to FP8 for block scale MOE hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn) @@ -1320,7 +1288,6 @@ def run_fp8_block_moe(): local_expert_offset=local_expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, - tile_tokens_dim=tile_tokens_dim, routing_method_type=routing_method_type, use_shuffled_weight=use_shuffled_weight, weight_layout=weight_layout, @@ -1381,7 +1348,6 @@ def run_fp8_block_moe(): cur_res["routed_scaling_factor"] = routed_scaling_factor cur_res["local_expert_offset"] = local_expert_offset cur_res["local_num_experts"] = local_num_experts - cur_res["tile_tokens_dim"] = tile_tokens_dim cur_res["routing_method"] = args.routing_method cur_res["use_shuffled_weight"] = use_shuffled_weight cur_res["weight_layout"] = weight_layout @@ -1448,7 +1414,6 @@ def testTrtllmFp8PerTensorScaleMoe(args): ) local_expert_offset = args.local_expert_offset local_num_experts = args.local_num_experts or num_experts - tile_tokens_dim = args.tile_tokens_dim routing_method_type = args.routing_method_type use_routing_scales_on_input = args.use_routing_scales_on_input is_cuda_graph_compatible = not args.no_cuda_graph @@ -1527,7 +1492,6 @@ def run_fp8_per_tensor_moe(): local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, use_routing_scales_on_input=use_routing_scales_on_input, - tile_tokens_dim=tile_tokens_dim, routing_method_type=routing_method_type, ) @@ -1585,7 +1549,6 @@ def run_fp8_per_tensor_moe(): cur_res["routed_scaling_factor"] = routed_scaling_factor cur_res["local_expert_offset"] = local_expert_offset cur_res["local_num_experts"] = local_num_experts - cur_res["tile_tokens_dim"] = tile_tokens_dim cur_res["routing_method"] = args.routing_method cur_res["use_routing_bias"] = args.use_routing_bias cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input diff --git a/benchmarks/samples/sample_testlist_output.csv b/benchmarks/samples/sample_testlist_output.csv index d856d37ab0..b07c523ecb 100644 --- a/benchmarks/samples/sample_testlist_output.csv +++ b/benchmarks/samples/sample_testlist_output.csv @@ -1,4 +1,4 @@ 
-routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,tile_tokens_dim,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command +routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command BatchPrefillWithPagedKVCacheWrapper,0.01244799979031086,0.0009464459008260536,13.963516944729905,0.3050282827732261,fa2,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B BatchPrefillWithPagedKVCacheWrapper,0.01839040070772171,0.00021363710731210026,9.45155349045863,0.20646597430613514,cudnn,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B BatchPrefillWithPagedKVCacheWrapper,0.008396799862384795,5.550615129103214e-05,20.70048814413847,0.45219512936224815,trtllm-gen,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command 
--case_tag Llama-3.1-70B diff --git a/benchmarks/samples/sample_testlist_output.txt b/benchmarks/samples/sample_testlist_output.txt index 69a3961f87..d2c5cc4fa1 100644 --- a/benchmarks/samples/sample_testlist_output.txt +++ b/benchmarks/samples/sample_testlist_output.txt @@ -292,7 +292,7 @@ 2025-09-23 00:32:18,247 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cutlass_autotun:: median time 0.009 ms; std 0.000 ms; achieved tflops 6.372 TFLOPs/sec; achieved tb_per_sec 0.401 TB/sec [PERF] trtllm_autotune:: median time 0.011 ms; std 0.000 ms; achieved tflops 5.410 TFLOPs/sec; achieved tb_per_sec 0.340 TB/sec -[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) [INFO] Running testTrtllmFp4BlockScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -303,7 +303,7 @@ [VVERBOSE] gemm1_weights_fp4.shape = torch.Size([256, 2048, 512]) [VVERBOSE] gemm2_weights_fp4.shape = torch.Size([256, 1024, 512]) [PERF] trtllm :: median time 0.224 ms; std 0.000 ms; achieved tflops 230.555 TFLOPs/sec; achieved tb_per_sec 1.818 TB/sec -[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=8, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='renormalize_naive', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, 
routing_method_type=4, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=8, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='renormalize_naive', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=4, gated_act_type=0) [INFO] Running testTrtllmFp4BlockScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -314,7 +314,7 @@ [VVERBOSE] gemm1_weights_fp4.shape = torch.Size([128, 2048, 512]) [VVERBOSE] gemm2_weights_fp4.shape = torch.Size([128, 1024, 512]) [PERF] trtllm :: median time 0.226 ms; std 0.000 ms; achieved tflops 227.846 TFLOPs/sec; achieved tb_per_sec 0.903 TB/sec -[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) [INFO] Running testTrtllmFp8BlockScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -325,7 +325,7 @@ [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([256, 2048, 1024]) [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([256, 1024, 1024]) [PERF] trtllm :: median time 0.557 ms; std 0.000 ms; achieved tflops 92.607 TFLOPs/sec; achieved tb_per_sec 1.455 TB/sec -[INFO] args = Namespace(routine='trtllm_fp8_per_tensor_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, 
dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='llama4', use_shuffled_weight=False, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=True, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=3, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp8_per_tensor_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='llama4', use_shuffled_weight=False, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=True, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=3, gated_act_type=0) [INFO] Running testTrtllmFp8PerTensorScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -336,7 +336,7 @@ [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([128, 2048, 1024]) [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([128, 1024, 1024]) [PERF] trtllm :: median time 0.123 ms; std 0.000 ms; achieved tflops 52.340 TFLOPs/sec; achieved tb_per_sec 3.299 TB/sec -[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='renormalize', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=1, gated_act_type=0) +[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='renormalize', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=1, 
gated_act_type=0) [INFO] Running testTrtllmFp8BlockScaleMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' @@ -347,7 +347,7 @@ [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([128, 2048, 1024]) [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([128, 1024, 1024]) [PERF] trtllm :: median time 0.109 ms; std 0.000 ms; achieved tflops 59.297 TFLOPs/sec; achieved tb_per_sec 3.740 TB/sec -[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_base', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=2, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) +[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_base', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=2, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0) [INFO] Running testCutlassFusedMoe [INFO] FlashInfer version: 0.3.1 [VVERBOSE] gpu_name = 'NVIDIA_B200' diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index f3c45e2ec0..fc6393237f 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1386,8 +1386,6 @@ Tensor trtllm_fp8_per_tensor_scale_moe( auto launcher = std::make_unique( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar); - // Note: Original code passes tile_N where tile_tokens_dim is expected - // This seems incorrect but we match the original behavior launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout, use_routing_scales_on_input); @@ -1470,8 +1468,6 @@ Tensor trtllm_fp8_block_scale_moe( auto launcher = std::make_unique( routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale); - // Note: Original code passes tile_N where tile_tokens_dim is expected - // This seems incorrect but we match the original behavior launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout); diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 83f186673b..b4444aa431 100644 --- a/flashinfer/fused_moe/core.py +++ 
b/flashinfer/fused_moe/core.py @@ -1952,7 +1952,6 @@ def trtllm_fp8_per_tensor_scale_moe( local_num_experts: int, routed_scaling_factor: Optional[float], use_routing_scales_on_input: bool, - tile_tokens_dim: Optional[int] = None, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, @@ -1977,7 +1976,6 @@ def trtllm_fp8_per_tensor_scale_moe( local_num_experts: Number of experts handled by this device routed_scaling_factor: Scaling factor for routing use_routing_scales_on_input: Whether to use routing scales on input - tile_tokens_dim: Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) @@ -1985,12 +1983,6 @@ def trtllm_fp8_per_tensor_scale_moe( Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ - if tile_tokens_dim is not None: - logger.warning_once( - "tile_tokens_dim in trtllm_fp8_per_tensor_scale_moe is planned for deprecation " - "in a future release. Please remove it from your code as tile_tokens_dim will no " - "longer be supported after v0.5.0." - ) return get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, @@ -2032,7 +2024,6 @@ def trtllm_fp8_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int] = None, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, @@ -2058,19 +2049,12 @@ def trtllm_fp8_block_scale_moe( local_expert_offset: Offset of local experts in global expert space local_num_experts: Number of experts handled by this device routed_scaling_factor: Scaling factor for routing - tile_tokens_dim: Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ - if tile_tokens_dim is not None: - logger.warning_once( - "tile_tokens_dim in trtllm_fp8_block_scale_moe is planned for deprecation " - "in a future release. Please remove it from your code as tile_tokens_dim will no " - "longer be supported after v0.5.0." 
- ) output = torch.empty( hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device ) @@ -2125,7 +2109,6 @@ def trtllm_fp4_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int], routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, @@ -2176,7 +2159,6 @@ def trtllm_fp4_block_scale_moe( local_expert_offset (int): Offset of local experts in global expert space local_num_experts (int): Number of experts handled by this device routed_scaling_factor (Optional[float]): Scaling factor for routing (can be None for some routing methods) - tile_tokens_dim (Optional[int]): Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type (int): Type of routing method to use (default: 0) - 0: Default (Softmax -> TopK) - 1: Renormalize (TopK -> Softmax) @@ -2195,12 +2177,6 @@ def trtllm_fp4_block_scale_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ - if tile_tokens_dim is not None: - logger.warning_once( - "tile_tokens_dim in trtllm_fp4_block_scale_moe is planned for deprecation " - "in a future release. Please remove it from your code as tile_tokens_dim will no " - "longer be supported after v0.5.0." - ) return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( routing_logits, None, @@ -2262,7 +2238,6 @@ def trtllm_fp4_block_scale_routed_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int], routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, @@ -2315,7 +2290,6 @@ def trtllm_fp4_block_scale_routed_moe( local_expert_offset (int): Offset of local experts in global expert space local_num_experts (int): Number of experts handled by this device routed_scaling_factor (Optional[float]): Scaling factor for routing (can be None for some routing methods) - tile_tokens_dim (Optional[int]): Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type (int): Type of routing method to use (default: 0) - 0: Default (Softmax -> TopK) - 1: Renormalize (TopK -> Softmax) @@ -2334,12 +2308,6 @@ def trtllm_fp4_block_scale_routed_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ - if tile_tokens_dim is not None: - logger.warning_once( - "tile_tokens_dim in trtllm_fp4_block_scale_routed_moe is planned for deprecation " - "in a future release. Please remove it from your code as tile_tokens_dim will no " - "longer be supported after v0.5.0." 
- ) return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( None, topk_ids, diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 747946fc09..35f4ad61e7 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -208,7 +208,6 @@ def _run_moe_computation(self, runtime_args): local_expert_offset=0, local_num_experts=self.config["num_experts"], routed_scaling_factor=self.config["routed_scaling"], - tile_tokens_dim=None, routing_method_type=self.config["routing_method_type"], gated_act_type=self.config["gated_act_type"], do_finalize=True, @@ -799,7 +798,6 @@ def call_moe( 0, num_experts, routed_scaling, - None, routing_method_type, use_shuffled_weight=static_data["use_shuffled_weight"], weight_layout=static_data["weight_layout"], @@ -979,7 +977,6 @@ def call_moe( routed_scaling, routing_method_type == RoutingMethodType.Llama4, # Use_routing_scales_on_input - None, routing_method_type, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index be39bda225..fb3feba4b7 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -180,7 +180,6 @@ def test_trtllm_gen_routed_fused_moe( 0, # local_expert_offset num_experts, None, # routed_scaling_factor - None, # tile_tokens_dim routing_method_type.value, True, # do_finalize enable_pdl, @@ -234,7 +233,6 @@ def test_trtllm_gen_routed_fused_moe( 0, # local_expert_offset num_experts, None, # routed_scaling_factor - None, # tile_tokens_dim routing_method_type.value, True, # do_finalize enable_pdl, From 636a3abbb96e363c65a609266147c88219332f57 Mon Sep 17 00:00:00 2001 From: Aditya K Kamath <12785368+AKKamath@users.noreply.github.com> Date: Thu, 13 Nov 2025 19:55:34 -0800 Subject: [PATCH 057/130] [Feature] Support batch prefill for POD Attention (#2079) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: @Edenzzzz ## ๐Ÿ“Œ Description Fixes https://github.com/flashinfer-ai/flashinfer/issues/1022. Unlike https://github.com/flashinfer-ai/flashinfer/pull/1231, this splits the inputs into separate prefill and decode inputs. It probably should be possible to automatically handle this splitting in Python so you can simply just provide a single batch of requests? 
To run the benchmark: `python benchmarks/bench_mixed_attention.py`

Performance:

===== Benchmark 1: (kv_len, qo_len) set =====
Prefill = 2 requests, 2048 Q len, 2048 KV len
Decode = 128 requests, 2048 KV len
Elapsed time (Batched Prefill): 0.65 ms
Elapsed time (Batched POD Attention): 0.46 ms
Elapsed time (Persistent BatchAttention): 0.56 ms
**Batch POD speedup over Persistent BatchAttention: 1.22x**

===== Benchmark 2: (kv_len, qo_len) set =====
Prefill = 1 request, 2048 Q len, 2048 KV len
Decode = 128 requests, 2048 KV len
Elapsed time (Batched Prefill): 0.55 ms
Elapsed time (Batched POD Attention): 0.41 ms
Elapsed time (POD Attention): 0.41 ms
Elapsed time (Sequential two kernels): 0.51 ms
Elapsed time (Persistent BatchAttention): 0.45 ms
**Batch POD speedup over Persistent BatchAttention: 1.11x**

===== Benchmark 3: (kv_len, qo_len) set =====
Prefill = 1 request, 4096 Q len, 4096 KV len
Decode = 128 requests, 4096 KV len
Elapsed time (Batched Prefill): 1.27 ms
Elapsed time (Batched POD Attention): 0.86 ms
Elapsed time (POD Attention): 0.82 ms
Elapsed time (Sequential two kernels): 1.15 ms
Elapsed time (Persistent BatchAttention): 1.08 ms
**Batch POD speedup over Persistent BatchAttention: 1.26x**

===== Benchmark 4: (kv_len, qo_len) set =====
Prefill = 1 request, 4096 Q len, 4096 KV len
Decode = 128 requests, 8192 KV len
Elapsed time (Batched Prefill): 2.15 ms
Elapsed time (Batched POD Attention): 1.52 ms
Elapsed time (POD Attention): 1.54 ms
Elapsed time (Sequential two kernels): 1.82 ms
Elapsed time (Persistent BatchAttention): 1.76 ms
**Batch POD speedup over Persistent BatchAttention: 1.16x**

===== Benchmark 5: (kv_len, qo_len) set =====
Prefill = 1 request, 6000 Q len, 7000 KV len
Decode = 128 requests, 8192 KV len
Elapsed time (Batched Prefill): 2.86 ms
Elapsed time (Batched POD Attention): 2.03 ms
Elapsed time (POD Attention): 1.95 ms
Elapsed time (Sequential two kernels): 2.52 ms
Elapsed time (Persistent BatchAttention): 2.45 ms
**Batch POD speedup over Persistent BatchAttention: 1.20x**

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Added a batched prefill+decode attention path with a public batch-oriented POD wrapper and JIT module export.
* **Performance**
  * Benchmarks extended to include batched-path timings, memory bandwidth, elapsed-time and comparative speedup metrics across expanded prefill/decode scenarios.
* **API**
  * Runtime binding for batched KV‑cache execution added; planning APIs now accept an optional colocated-CTA parameter that influences scheduling.
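For readers skimming the diff below, here is a minimal, self-contained sketch of how the new `BatchPODWithPagedKVCacheWrapper` is meant to be driven, distilled from the benchmark and docstring changes in this patch. The request mix, tensor shapes, and the 156 MB workspace size are illustrative values borrowed from `bench_mixed_attention.py`, not requirements; only the `plan()`/`run()` argument order is taken from the API introduced here.

```python
import torch
import flashinfer

device = "cuda:0"
num_qo_heads, num_kv_heads, head_dim, page_size = 64, 8, 128, 1

# One 512-token prefill request colocated with four decode requests (512 KV tokens each).
p_qo_lens, p_kv_lens = [512], [512]
d_qo_lens, d_kv_lens = [1] * 4, [512] * 4

def to_indptr(lens):
    # Exclusive-scan indptr over per-request lengths (page_size == 1, so KV lens == block counts).
    return torch.cat(
        [torch.tensor([0]), torch.cumsum(torch.tensor(lens), 0)]
    ).int().to(device)

p_q_indptr, p_kv_indptr = to_indptr(p_qo_lens), to_indptr(p_kv_lens)
d_q_indptr, d_kv_indptr = to_indptr(d_qo_lens), to_indptr(d_kv_lens)
kv_indices_p = torch.arange(int(p_kv_indptr[-1]), device=device, dtype=torch.int32)
kv_indices_d = torch.arange(int(d_kv_indptr[-1]), device=device, dtype=torch.int32)
last_page_len_p = torch.ones(len(p_kv_lens), dtype=torch.int32)
last_page_len_d = torch.ones(len(d_kv_lens), dtype=torch.int32)

workspace = torch.empty(156 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.BatchPODWithPagedKVCacheWrapper(workspace, kv_layout="NHD")
wrapper.plan(
    # Prefill params
    p_q_indptr, p_kv_indptr, kv_indices_p, last_page_len_p,
    # Decode params
    d_q_indptr, d_kv_indptr, kv_indices_d, last_page_len_d,
    # Common params
    num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim=head_dim,
    page_size=page_size, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
)

q_p = torch.randn(int(p_q_indptr[-1]), num_qo_heads, head_dim,
                  dtype=torch.bfloat16, device=device)
q_d = torch.randn(int(d_q_indptr[-1]), num_qo_heads, head_dim,
                  dtype=torch.bfloat16, device=device)
kv_p = torch.randn(int(p_kv_indptr[-1]), 2, page_size, num_kv_heads, head_dim,
                   dtype=torch.bfloat16, device=device).unbind(1)
kv_d = torch.randn(int(d_kv_indptr[-1]), 2, page_size, num_kv_heads, head_dim,
                   dtype=torch.bfloat16, device=device).unbind(1)

# Prefill output comes first, decode output second.
o_p, o_d = wrapper.run(q_p, kv_p, q_d, kv_d, causal_p=True)
```

Note that prefill and decode inputs stay separate end to end; this is the splitting discussed in the description above.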
--------- Co-authored-by: Aditya K Kamath Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Edenzzzz --- benchmarks/bench_mixed_attention.py | 102 +++- csrc/batch_pod.cu | 350 ++++++++++++ csrc/batch_pod_customize_config.jinja | 43 ++ csrc/batch_pod_jit_binding.cu | 44 ++ csrc/batch_pod_kernel_inst.jinja | 31 ++ csrc/batch_prefill.cu | 5 +- csrc/batch_prefill_jit_binding.cu | 2 +- csrc/pod_jit_binding.cu | 2 +- flashinfer/__init__.py | 1 + flashinfer/decode.py | 4 +- flashinfer/jit/__init__.py | 1 + flashinfer/jit/attention/__init__.py | 1 + flashinfer/jit/attention/modules.py | 167 +++++- flashinfer/pod.py | 584 ++++++++++++++++++++- flashinfer/prefill.py | 2 + flashinfer/sparse.py | 2 + include/flashinfer/attention/batch_pod.cuh | 394 ++++++++++++++ include/flashinfer/attention/scheduler.cuh | 6 +- 18 files changed, 1725 insertions(+), 16 deletions(-) create mode 100644 csrc/batch_pod.cu create mode 100644 csrc/batch_pod_customize_config.jinja create mode 100644 csrc/batch_pod_jit_binding.cu create mode 100644 csrc/batch_pod_kernel_inst.jinja create mode 100644 include/flashinfer/attention/batch_pod.cuh diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index 9bb6616737..7414a58af0 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -23,7 +23,10 @@ def run_bench( q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32) seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() - d_seq_lens_blocks = ( + p_seq_lens_blocks = torch.ceil( + torch.tensor(p_kv_lens, dtype=torch.int32) / page_block_size + ).int() + d_seq_lens_blocks = torch.ceil( torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size ).int() @@ -31,6 +34,14 @@ def run_bench( kv_indptr = torch.cat( [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 ).int() + + p_q_indptr = torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0 + ).int() + p_kv_indptr = torch.cat( + [torch.tensor([0]), torch.cumsum(p_seq_lens_blocks, 0)], dim=0 + ).int() + d_q_indptr = torch.cat( [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 ).int() @@ -46,7 +57,7 @@ def run_bench( device, dtype=torch.bfloat16 ) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + workspace_buffer = torch.empty(156 * 1024 * 1024, dtype=torch.uint8, device=device) kv_layout = "NHD" wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( @@ -90,7 +101,67 @@ def run_bench( o_persistent, _ = wrapper_persistent.run(q, kv_data) measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data)) ms_persistent = np.mean(measurements_persistent) + + # Batched POD Attention + q_d = q[: d_q_indptr[-1]] + kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) + q_p = q[d_q_indptr[-1] :] + kv_p = kv_data[d_kv_indptr[-1] :].unbind(1) + kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32) + kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32) + + last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1 + last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1 + wrapper_pod = flashinfer.BatchPODWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout=kv_layout, + ) + + wrapper_pod.plan( + # Prefill params + p_q_indptr.to(device), + p_kv_indptr.to(device), + kv_indices_p.to(device), + last_page_len_p, + # Decode params + d_q_indptr.to(device), + 
d_kv_indptr.to(device), + kv_indices_d.to(device), + last_page_len_d, + # Common params + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_block_size, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + o_p_batch, o_d_batch = wrapper_pod.run( + q_p, + kv_p, + q_d, + kv_d, + causal_p=causal, + ) + o_batch_pod = torch.cat([o_d_batch, o_p_batch], dim=0) + + # Verify output matches + torch.testing.assert_close( + o_batch_pod, o, rtol=4e-3, atol=4e-3, msg="Batch POD-Attention decode mismatch!" + ) + measurements = bench_gpu_time( + lambda: wrapper_pod.run( + q_p, + kv_p, + q_d, + kv_d, + causal_p=causal, + ) + ) + ms_batch_pod = np.median(measurements) + if len(p_kv_lens) == 1: + # Single POD attention q_d = q[: d_q_indptr[-1]] kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) q_p = q[d_q_indptr[-1] :] @@ -127,7 +198,7 @@ def run_bench( o_pod = torch.cat([o_d, o_p], dim=0) # Verify output matches torch.testing.assert_close( - o, o_pod, rtol=1e-3, atol=1e-3, msg="POD-Attention output mismatch!" + o, o_pod, rtol=4e-3, atol=4e-3, msg="POD-Attention output mismatch!" ) measurements = bench_gpu_time( lambda: wrapper_pod.run( @@ -177,10 +248,15 @@ def _run_single_prefill(): ms_seq_two_kernels = ms_prefill + ms_decode print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") + print(f"Elapsed time (Batched POD Attention): {ms_batch_pod:.2f} ms") if len(p_kv_lens) == 1: print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms") print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms") + print( + f"Batch POD speedup over Persistent BatchAttention: {ms_persistent / ms_batch_pod:.2f}x" + ) + total_bytes = ( q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() ) @@ -189,6 +265,10 @@ def _run_single_prefill(): bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3) print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s") + bandwidth_batch_pod_gb_s = total_bytes / (ms_batch_pod * 1e-3) / (1024**3) + print( + f"Memory bandwidth (Batched POD Attention): {bandwidth_batch_pod_gb_s:.2f} GB/s" + ) if len(p_kv_lens) == 1: bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3) print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s") @@ -207,10 +287,18 @@ def _run_single_prefill(): torch.random.manual_seed(42) # Irregular sequence lengths for prefill and decode - d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128] - d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128] - p_q_configs = [[2048], [4096], [4096], [6000]] - p_kv_configs = [[2048], [4096], [4096], [7000]] + d_q_len_configs = [[1] * 128] * 7 + d_kv_len_configs = [ + [2048] * 128, + [2048] * 128, + [2048] * 128, + [2048] * 128, + [4096] * 128, + [8192] * 128, + [8192] * 128, + ] + p_q_configs = [[512], [1536], [2048] * 2, [2048], [4096], [4096], [6000]] + p_kv_configs = [[512], [1536], [2048] * 2, [2048], [4096], [4096], [7000]] page_block_size = 1 num_kv_heads = 8 diff --git a/csrc/batch_pod.cu b/csrc/batch_pod.cu new file mode 100644 index 0000000000..33aa5e753f --- /dev/null +++ b/csrc/batch_pod.cu @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "batch_pod_config.inc" +#include "tvm_ffi_utils.h" + +namespace flashinfer { +template +cudaError_t BatchPODWithKVCacheTensorDispatched(PrefillParams prefill_params, + typename PrefillParams::DTypeO* tmp_v_p, + float* tmp_s_p, DecodeParams decode_params, + typename DecodeParams::DTypeO* tmp_v_d, + float* tmp_s_d, bool enable_pdl, + cudaStream_t stream, int* sm_aware_sched); + +} // namespace flashinfer + +using namespace flashinfer; + +using tvm::ffi::Array; +using tvm::ffi::Optional; + +void batch_pod_with_kv_cache_tensor( + // Prefill params + TensorView float_workspace_buffer_p, TensorView int_workspace_buffer_p, + Array plan_info_vec_p, TensorView q_p, TensorView paged_k_cache_p, + TensorView paged_v_cache_p, TensorView qo_indptr_p, TensorView paged_kv_indptr_p, + TensorView paged_kv_indices_p, TensorView paged_kv_last_page_len_p, TensorView o_p, + Optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, + int64_t window_left_p, Optional maybe_custom_mask_p, + Optional maybe_mask_indptr_p, Optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, + // Decode params + TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, + Array plan_info_vec_d, TensorView q_d, TensorView paged_k_cache_d, + TensorView paged_v_cache_d, TensorView qo_indptr_d, TensorView paged_kv_indptr_d, + TensorView paged_kv_indices_d, TensorView paged_kv_last_page_len_d, TensorView o_d, + Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, + int64_t window_left_d, Optional maybe_custom_mask_d, + Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, + double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, + bool enable_pdl, TensorView sm_aware_sched) { + // Prefill setup + PrefillPlanInfo plan_info_p; + plan_info_p.FromVector(std::vector(plan_info_vec_p.begin(), plan_info_vec_p.end())); + QKVLayout kv_layout_p = static_cast(layout_p); + int64_t batch_size_p = paged_kv_indptr_p.size(0) - 1; + int64_t num_qo_heads = q_p.size(1); + + int64_t num_kv_heads_p, page_size_p; + uint32_t head_dim_qk_p = q_p.size(2); + if (kv_layout_p == QKVLayout::kHND) { + num_kv_heads_p = paged_k_cache_p.size(1); + page_size_p = paged_k_cache_p.size(2); + } else { + page_size_p = paged_k_cache_p.size(1); + num_kv_heads_p = paged_k_cache_p.size(2); + } + + if (maybe_lse_p.has_value()) { + const auto& lse = maybe_lse_p.value(); + TVM_FFI_ICHECK_EQ(lse.size(0), q_p.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q_p.size(1)); + } + + void* float_buffer_ptr_p = static_cast(float_workspace_buffer_p.data_ptr()); + void* int_buffer_ptr_p = static_cast(int_workspace_buffer_p.data_ptr()); + + const MaskMode mask_mode_p = static_cast(mask_mode_code_p); + + // get q_stride_n and q_stride_h + const auto q_stride_n_p = q_p.stride(0); + const auto q_stride_h_p = q_p.stride(1); + + // get kv_cache_strides + const int64_t* kv_cache_strides_p = nullptr; + auto k_strides_p = paged_k_cache_p.strides(); + auto v_strides_p = paged_v_cache_p.strides(); + 
TVM_FFI_ICHECK_EQ(k_strides_p.size(), v_strides_p.size()); + for (int i = 0; i < k_strides_p.size(); ++i) { + TVM_FFI_ICHECK_EQ(k_strides_p[i], v_strides_p[i]); + } + kv_cache_strides_p = k_strides_p.data(); + + cudaSetDevice(float_workspace_buffer_p.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer_p.device()); + + // Decode setup (TensorView decode = batched prefill) + PrefillPlanInfo plan_info_d; + plan_info_d.FromVector(std::vector(plan_info_vec_d.begin(), plan_info_vec_d.end())); + QKVLayout kv_layout_d = static_cast(layout_d); + int64_t batch_size_d = paged_kv_indptr_d.size(0) - 1; + int64_t num_qo_heads_d = q_d.size(1); + + TVM_FFI_ICHECK_EQ(num_qo_heads, num_qo_heads_d) + << "POD currently requires same # Query heads for prefill and decode"; + + int64_t num_kv_heads_d, page_size_d; + uint32_t head_dim_qk_d = q_d.size(2); + if (kv_layout_d == QKVLayout::kHND) { + num_kv_heads_d = paged_k_cache_d.size(1); + page_size_d = paged_k_cache_d.size(2); + } else { + page_size_d = paged_k_cache_d.size(1); + num_kv_heads_d = paged_k_cache_d.size(2); + } + TVM_FFI_ICHECK_EQ(num_kv_heads_p, num_kv_heads_d) + << "POD currently requires same # KV heads for prefill and decode; Prefill: " + << num_kv_heads_p << ", Decode: " << num_kv_heads_d; + + if (maybe_lse_d.has_value()) { + const auto& lse = maybe_lse_d.value(); + TVM_FFI_ICHECK_EQ(lse.size(0), q_d.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q_d.size(1)); + } + + void* float_buffer_ptr_d = static_cast(float_workspace_buffer_d.data_ptr()); + void* int_buffer_ptr_d = static_cast(int_workspace_buffer_d.data_ptr()); + + const MaskMode mask_mode_d = static_cast(mask_mode_code_d); + + // get q_stride_n and q_stride_h + const auto q_stride_n_d = q_d.stride(0); + const auto q_stride_h_d = q_d.stride(1); + + // get kv_cache_strides + const int64_t* kv_cache_strides_d = nullptr; + auto k_strides_d = paged_k_cache_d.strides(); + auto v_strides_d = paged_v_cache_d.strides(); + TVM_FFI_ICHECK_EQ(k_strides_d.size(), v_strides_d.size()); + for (int i = 0; i < k_strides_d.size(); ++i) { + TVM_FFI_ICHECK_EQ(k_strides_d[i], v_strides_d[i]); + } + kv_cache_strides_d = k_strides_d.data(); + + // Already handled by prefill + // cudaSetDevice(float_workspace_buffer_d.device().device_id); + // const cudaStream_t stream = get_stream(float_workspace_buffer_d.device()); + + DISPATCH_context( + MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, + USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, [&] { + PrefillParams prefill_params; + DTypeO* tmp_v_p = nullptr; + float* tmp_s_p = nullptr; + { + PrefillParams& params = prefill_params; + params.q = static_cast(q_p.data_ptr()); + paged_kv_t paged_kv( + num_kv_heads_p, page_size_p, HEAD_DIM_VO, batch_size_p, kv_layout_p, + static_cast(paged_k_cache_p.data_ptr()), + static_cast(paged_v_cache_p.data_ptr()), kv_cache_strides_p, + static_cast(paged_kv_indices_p.data_ptr()), + static_cast(paged_kv_indptr_p.data_ptr()), + static_cast(paged_kv_last_page_len_p.data_ptr())); + params.paged_kv = paged_kv; + params.q_indptr = static_cast(qo_indptr_p.data_ptr()); + params.o = static_cast(o_p.data_ptr()); + + params.lse = maybe_lse_p.has_value() ? 
static_cast(maybe_lse_p.value().data_ptr()) + : nullptr; + params.num_qo_heads = num_qo_heads; + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); + params.q_stride_n = q_stride_n_p; + params.q_stride_h = q_stride_h_p; + params.window_left = window_left_p; + + params.request_indices = nullptr; + params.qo_tile_indices = nullptr; + params.kv_tile_indices = nullptr; + params.merge_indptr = nullptr; + params.o_indptr = nullptr; + params.kv_chunk_size_ptr = nullptr; + params.block_valid_mask = nullptr; + params.total_num_rows = nullptr; + params.max_total_num_rows = 0; + params.padded_batch_size = 0; + params.partition_kv = false; + + params.maybe_mask_indptr = + maybe_mask_indptr_p.has_value() + ? static_cast(maybe_mask_indptr_p.value().data_ptr()) + : nullptr; + params.maybe_alibi_slopes = + maybe_alibi_slopes_p.has_value() + ? static_cast(maybe_alibi_slopes_p.value().data_ptr()) + : nullptr; + params.logits_soft_cap = logits_soft_cap_p; + params.sm_scale = sm_scale_p; + params.rope_rcp_scale = rope_rcp_scale_p; + params.rope_rcp_theta = rope_rcp_theta_p; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.kv_tile_indices_offset); + params.o_indptr = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.kv_chunk_size_ptr_offset); + if (plan_info_p.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.merge_indptr_offset); + tmp_v_p = GetPtrFromBaseOffset(float_buffer_ptr_p, plan_info_p.v_offset); + tmp_s_p = GetPtrFromBaseOffset(float_buffer_ptr_p, plan_info_p.s_offset); + if (plan_info_p.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info_p.padded_batch_size; + params.max_total_num_rows = plan_info_p.total_num_rows; + if (plan_info_p.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr_p, plan_info_p.total_num_rows_offset); + } + } + + DecodeParams decode_params; + DTypeO* tmp_v_d = nullptr; + float* tmp_s_d = nullptr; + { + DecodeParams& params = decode_params; + params.q = static_cast(q_d.data_ptr()); + paged_kv_t paged_kv( + num_kv_heads_d, page_size_d, HEAD_DIM_VO, batch_size_d, kv_layout_d, + static_cast(paged_k_cache_d.data_ptr()), + static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides_d, + static_cast(paged_kv_indices_d.data_ptr()), + static_cast(paged_kv_indptr_d.data_ptr()), + static_cast(paged_kv_last_page_len_d.data_ptr())); + params.paged_kv = paged_kv; + params.q_indptr = static_cast(qo_indptr_d.data_ptr()); + params.o = static_cast(o_d.data_ptr()); + + params.lse = maybe_lse_d.has_value() ? 
static_cast(maybe_lse_d.value().data_ptr()) + : nullptr; + params.num_qo_heads = num_qo_heads; + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); + params.q_stride_n = q_stride_n_d; + params.q_stride_h = q_stride_h_d; + params.window_left = window_left_d; + + params.request_indices = nullptr; + params.qo_tile_indices = nullptr; + params.kv_tile_indices = nullptr; + params.merge_indptr = nullptr; + params.o_indptr = nullptr; + params.kv_chunk_size_ptr = nullptr; + params.block_valid_mask = nullptr; + params.total_num_rows = nullptr; + params.max_total_num_rows = 0; + params.padded_batch_size = 0; + params.partition_kv = false; + + params.maybe_mask_indptr = + maybe_mask_indptr_d.has_value() + ? static_cast(maybe_mask_indptr_d.value().data_ptr()) + : nullptr; + params.maybe_alibi_slopes = + maybe_alibi_slopes_d.has_value() + ? static_cast(maybe_alibi_slopes_d.value().data_ptr()) + : nullptr; + params.logits_soft_cap = logits_soft_cap_d; + params.sm_scale = sm_scale_d; + params.rope_rcp_scale = rope_rcp_scale_d; + params.rope_rcp_theta = rope_rcp_theta_d; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.kv_tile_indices_offset); + params.o_indptr = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.kv_chunk_size_ptr_offset); + if (plan_info_d.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.merge_indptr_offset); + tmp_v_d = GetPtrFromBaseOffset(float_buffer_ptr_d, plan_info_d.v_offset); + tmp_s_d = GetPtrFromBaseOffset(float_buffer_ptr_d, plan_info_d.s_offset); + if (plan_info_d.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info_d.padded_batch_size; + params.max_total_num_rows = plan_info_d.total_num_rows; + if (plan_info_d.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr_d, plan_info_d.total_num_rows_offset); + } + } + + constexpr bool use_custom_mask_p = MASK_MODE_P == MaskMode::kCustom; + using PrefillAttentionVariant = + DefaultAttention; + constexpr bool use_custom_mask_d = MASK_MODE_D == MaskMode::kCustom; + using DecodeAttentionVariant = + DefaultAttention; + + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + int num_sm = 0; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + // SM-aware scheduling buffer uses num_sm + 2 entries + // num_sm entries for counters for each SM, and + // 2 entries for keeping track of blockIds for prefill and decode + assert( + sm_aware_sched.ndim() == 1 && sm_aware_sched.size(0) == num_sm + 2 && + "sm_aware_sched tensor has incorrect shape or type, should be (num_sm + 2,) of int32"); + DISPATCH_CTA_TILE_Q(plan_info_p.cta_tile_q, CTA_TILE_Q_P, { + constexpr size_t CTA_TILE_Q_D = 16; + cudaError_t status = flashinfer::BatchPODWithKVCacheTensorDispatched< + HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, CTA_TILE_Q_P, + MASK_MODE_P, CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant, + DecodeAttentionVariant>(prefill_params, tmp_v_p, tmp_s_p, decode_params, tmp_v_d, + tmp_s_d, enable_pdl, stream, + 
static_cast(sm_aware_sched.data_ptr())); + TVM_FFI_ICHECK(status == cudaSuccess) + << "BatchPODWithKVCache kernel launch failed, error: " << cudaGetErrorString(status); + return status; + }); + }); +} diff --git a/csrc/batch_pod_customize_config.jinja b/csrc/batch_pod_customize_config.jinja new file mode 100644 index 0000000000..9f27b42953 --- /dev/null +++ b/csrc/batch_pod_customize_config.jinja @@ -0,0 +1,43 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace flashinfer; + +using DTypeQ = {{ dtype_q }}; +using DTypeKV = {{ dtype_kv }}; +using DTypeO = {{ dtype_o }}; +using IdType = {{ idtype }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; +constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; + +constexpr auto USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }}; +constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }}; +constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }}; + +constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }}; +constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }}; +constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }}; + +constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; +constexpr bool USE_LOGITS_SOFT_CAP = false; + +using PrefillParams = BatchPrefillPagedParams; +using DecodeParams = BatchPrefillPagedParams; + +#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ + USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \ + DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, { \ + DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, { \ + __VA_ARGS__(); \ + }); \ +}); diff --git a/csrc/batch_pod_jit_binding.cu b/csrc/batch_pod_jit_binding.cu new file mode 100644 index 0000000000..c7a8a5ea6b --- /dev/null +++ b/csrc/batch_pod_jit_binding.cu @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "batch_pod_config.inc" +#include "tvm_ffi_utils.h" + +using tvm::ffi::Array; +using tvm::ffi::Optional; + +void batch_pod_with_kv_cache_tensor( + // Prefill params + TensorView float_workspace_buffer_p, TensorView int_workspace_buffer_p, + Array plan_info_vec_p, TensorView q_p, TensorView paged_k_cache_p, + TensorView paged_v_cache_p, TensorView qo_indptr_p, TensorView paged_kv_indptr_p, + TensorView paged_kv_indices_p, TensorView paged_kv_last_page_len_p, TensorView o_p, + Optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, + int64_t window_left_p, Optional maybe_custom_mask_p, + Optional maybe_mask_indptr_p, Optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, + // Decode params + TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, + Array plan_info_vec_d, TensorView q_d, TensorView paged_k_cache_d, + TensorView paged_v_cache_d, TensorView qo_indptr_d, TensorView paged_kv_indptr_d, + TensorView paged_kv_indices_d, TensorView paged_kv_last_page_len_d, TensorView o_d, + Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, + int64_t window_left_d, Optional maybe_custom_mask_d, + Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, + double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, + bool enable_pdl, TensorView sm_aware_sched); + +// Batch-request prefill attention with KV-Cache operator +TVM_FFI_DLL_EXPORT_TYPED_FUNC(batch_pod_with_kv_cache_tensor, batch_pod_with_kv_cache_tensor); diff --git a/csrc/batch_pod_kernel_inst.jinja b/csrc/batch_pod_kernel_inst.jinja new file mode 100644 index 0000000000..cb2c39d32b --- /dev/null +++ b/csrc/batch_pod_kernel_inst.jinja @@ -0,0 +1,31 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "batch_pod_config.inc" + +using namespace flashinfer; + +namespace flashinfer { +constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom; +constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom; +// Not sure about the below declaration +constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + +{% for cta_tile_q in [16, 64, 128] %} +template cudaError_t BatchPODWithKVCacheTensorDispatched< + {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, + {{ use_fp16_qk_reduction }}, /*CTA_TILE_Q_P=*/{{cta_tile_q}}, {{ mask_mode_p }}, + /*CTA_TILE_Q_D=*/16, {{ mask_mode_d }}, {{ variant_name_p }}, + {{ variant_name_d }}, PrefillParams, DecodeParams>( + PrefillParams prefill_params, {{ dtype_o }}* tmp_v_p, float *tmp_s_p, + DecodeParams decode_params, {{ dtype_o }}* tmp_v_d, float *tmp_s_d, + bool enable_pdl, cudaStream_t stream, int* sm_aware_sched); +{% endfor %} +} diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index 5d7182bdc5..9e0d77582f 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -50,7 +50,7 @@ Array BatchPrefillWithKVCachePlan( TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, - bool disable_split_kv) { + bool disable_split_kv, int64_t num_colocated_ctas = 0) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -66,7 +66,8 @@ Array 
BatchPrefillWithKVCachePlan( int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), static_cast(kv_indptr.data_ptr()), total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, - /*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, stream); + /*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, num_colocated_ctas, + stream); TVM_FFI_ICHECK(status == cudaSuccess) << "Failed to plan prefill with error: " << cudaGetErrorString(status); diff --git a/csrc/batch_prefill_jit_binding.cu b/csrc/batch_prefill_jit_binding.cu index da1e1981dc..3dda0f115a 100644 --- a/csrc/batch_prefill_jit_binding.cu +++ b/csrc/batch_prefill_jit_binding.cu @@ -25,7 +25,7 @@ Array BatchPrefillWithKVCachePlan( TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, - bool disable_split_kv); + bool disable_split_kv, int64_t num_colocated_ctas); void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer, Array plan_info_vec, diff --git a/csrc/pod_jit_binding.cu b/csrc/pod_jit_binding.cu index 915e4bcdbf..1da0bf7bae 100644 --- a/csrc/pod_jit_binding.cu +++ b/csrc/pod_jit_binding.cu @@ -37,5 +37,5 @@ void pod_with_kv_cache_tensor( double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl); -// Batch-request prefill attention with KV-Cache operator +// Single prefill, Batch-request decode attention with KV-Cache operator TVM_FFI_DLL_EXPORT_TYPED_FUNC(pod_with_kv_cache_tensor, pod_with_kv_cache_tensor); diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 8cedc9261e..faad4f12a3 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -100,6 +100,7 @@ from .page import get_batch_indices_positions as get_batch_indices_positions from .page import get_seq_lens as get_seq_lens from .pod import PODWithPagedKVCacheWrapper as PODWithPagedKVCacheWrapper +from .pod import BatchPODWithPagedKVCacheWrapper as BatchPODWithPagedKVCacheWrapper from .prefill import ( BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper, ) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 574f8a024c..5826e743da 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1060,6 +1060,7 @@ def plan( window_left, fixed_split_size, disable_split_kv, + 0, # num_colocated_ctas ) else: if self._jit_module is not None: @@ -2920,7 +2921,7 @@ def fast_decode_plan( kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) try: - # Make sure we pass exactly 15 arguments for tensor core version + # Make sure we pass exactly 16 arguments for tensor core version self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -2940,6 +2941,7 @@ def fast_decode_plan( window_left, fixed_split_size, disable_split_kv, + 0, # num_colocated_ctas ) except Exception as e: raise RuntimeError(f"Error in standard plan: {e}") from e diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index bc4132ec9c..1aa6f44dbd 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -42,6 +42,7 @@ gen_customize_single_prefill_module as gen_customize_single_prefill_module, ) from .attention import gen_fmha_cutlass_sm100a_module as 
gen_fmha_cutlass_sm100a_module +from .attention import gen_batch_pod_module as gen_batch_pod_module from .attention import gen_pod_module as gen_pod_module from .attention import gen_single_decode_module as gen_single_decode_module from .attention import gen_single_prefill_module as gen_single_prefill_module diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py index 2ae6f30729..583c8d1615 100644 --- a/flashinfer/jit/attention/__init__.py +++ b/flashinfer/jit/attention/__init__.py @@ -33,6 +33,7 @@ gen_customize_single_prefill_module as gen_customize_single_prefill_module, ) from .modules import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module +from .modules import gen_batch_pod_module as gen_batch_pod_module from .modules import gen_pod_module as gen_pod_module from .modules import gen_single_decode_module as gen_single_decode_module from .modules import gen_single_prefill_module as gen_single_prefill_module diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index 475acdcd1e..fe895def12 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -630,6 +630,71 @@ def gen_pod_module( ) +def gen_batch_pod_module( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + head_dim: int, + pos_encoding_mode_p: int, + use_sliding_window_p: bool, + use_logits_soft_cap_p: bool, + use_fp16_qk_reduction: bool, + dtype_idx: torch.dtype, + pos_encoding_mode_d: int, + use_sliding_window_d: bool, + use_logits_soft_cap_d: bool, +) -> JitSpec: + uri = "batch_" + get_pod_uri( + dtype_q, + dtype_kv, + dtype_o, + head_dim, + pos_encoding_mode_p, + use_sliding_window_p, + use_logits_soft_cap_p, + use_fp16_qk_reduction, + dtype_idx, + pos_encoding_mode_d, + use_sliding_window_d, + use_logits_soft_cap_d, + ) + additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"] + additional_tensor_dtypes = ["uint8_t", "float"] + additional_scalar_names = [ + "logits_soft_cap", + "sm_scale", + "rope_rcp_scale", + "rope_rcp_theta", + ] + additional_scalar_dtypes = ["float", "float", "float", "float"] + variant_name_p = f"DefaultAttention" + variant_name_d = f"DefaultAttention" + variant_decl = "#include" + + return gen_customize_batch_pod_module( + uri, + dtype_q, + dtype_kv, + dtype_o, + dtype_idx, + head_dim, + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + variant_name_p, + variant_name_d, + variant_decl, + pos_encoding_mode_p=pos_encoding_mode_p, + use_sliding_window_p=use_sliding_window_p, + use_logits_soft_cap_p=use_logits_soft_cap_p, + pos_encoding_mode_d=pos_encoding_mode_d, + use_sliding_window_d=use_sliding_window_d, + use_logits_soft_cap_d=use_logits_soft_cap_d, + use_fp16_qk_reduction=use_fp16_qk_reduction, + ) + + def gen_customize_pod_module( uri: str, dtype_q: torch.dtype, @@ -698,6 +763,8 @@ def gen_customize_pod_module( ) os.makedirs(gen_directory, exist_ok=True) + generated_config_path = gen_directory / "pod_config.inc" + write_if_different(generated_config_path, generated_inc_str) source_paths = [] @@ -725,8 +792,106 @@ def gen_customize_pod_module( source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "pod_config.inc" + return gen_jit_spec(uri, source_paths) + + +def gen_customize_batch_pod_module( + uri: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim: int, + additional_tensor_names: 
List[str], + additional_tensor_dtypes: List[str], + additional_scalar_names: List[str], + additional_scalar_dtypes: List[str], + variant_name_p: str, + variant_name_d: str, + variant_decl: str, + pos_encoding_mode_p: int = 0, + use_sliding_window_p: bool = False, + use_logits_soft_cap_p: bool = False, + pos_encoding_mode_d: int = 0, + use_sliding_window_d: bool = False, + use_logits_soft_cap_d: bool = False, + use_fp16_qk_reduction: bool = False, +) -> JitSpec: + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri + + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + ) + + with open(jit_env.FLASHINFER_CSRC_DIR / "batch_pod_customize_config.jinja") as f: + config_templ = jinja2.Template(f.read()) + + with open(jit_env.FLASHINFER_CSRC_DIR / "batch_pod_kernel_inst.jinja") as f: + kernel_inst_templ = jinja2.Template(f.read()) + + kwargs = { + "additional_func_params": additional_func_params, + "additional_params_decl": additional_params_decl, + "additional_params_setter": additional_params_setter, + "variant_decl": variant_decl, + "variant_name_p": variant_name_p, + "variant_name_d": variant_name_d, + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "idtype": dtype_map[dtype_idx], + "head_dim_qk": head_dim, + "head_dim_vo": head_dim, + "pos_encoding_mode_p": pos_encoding_mode_literal[pos_encoding_mode_p], + "pos_encoding_mode_d": pos_encoding_mode_literal[pos_encoding_mode_d], + "use_sliding_window_p": str(use_sliding_window_p).lower(), + "use_logits_soft_cap_p": str(use_logits_soft_cap_p).lower(), + "use_sliding_window_d": str(use_sliding_window_d).lower(), + "use_logits_soft_cap_d": str(use_logits_soft_cap_d).lower(), + "use_fp16_qk_reduction": str(use_fp16_qk_reduction).lower(), + } + + generated_inc_str = config_templ.render( + **kwargs, + ) + + os.makedirs(gen_directory, exist_ok=True) + generated_config_path = gen_directory / "batch_pod_config.inc" write_if_different(generated_config_path, generated_inc_str) + + source_paths = [] + + for mask_mode_p in [0, 1, 2, 3]: + for mask_mode_d in [0, 1, 2, 3]: + kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p] + kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d] + + filename = f"batch_pod_kernel_mask_{mask_mode_p}p_{mask_mode_d}d.cu" + dest_path = gen_directory / filename + source_paths.append(dest_path) + source = kernel_inst_templ.render( + **kwargs, + ) + write_if_different(dest_path, source) + + for filename in [ + "batch_pod.cu", + "batch_pod_jit_binding.cu", + ]: + src_path = jit_env.FLASHINFER_CSRC_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) + return gen_jit_spec(uri, source_paths) diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 59e113f238..d0a66f7ae9 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -21,7 +21,7 @@ import torch -from .jit import gen_pod_module +from .jit import gen_pod_module, gen_batch_pod_module from .page import get_seq_lens from .prefill import get_batch_prefill_module from .quantization import packbits @@ -47,6 +47,12 @@ def get_pod_module(*args): return SimpleNamespace(run_tensor=module.pod_with_kv_cache_tensor) +@functools.cache +def get_batch_pod_module(*args): + module = gen_batch_pod_module(*args).build_and_load() + return 
SimpleNamespace(run_tensor=module.batch_pod_with_kv_cache_tensor) + + class PODWithPagedKVCacheWrapper: r"""Wrapper class for POD-Attention with paged kv-cache (first proposed in ``_) for batch of requests. @@ -413,6 +419,7 @@ def plan( window_left, -1, # fixed_split_size False, # disable_split_kv + 0, # num_colocated_ctas ) self._indptr_type = indptr.dtype @@ -610,3 +617,578 @@ def run( def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" pass + + +class BatchPODWithPagedKVCacheWrapper: + r"""Wrapper class for POD-Attention with paged kv-cache (first proposed in + ``_) for batch of requests. + + Check :ref:`our tutorial` for page table layout. + + Examples + -------- + >>> import torch + >>> import flashinfer + >>> num_layers = 8 + >>> num_qo_heads = 64 + >>> num_kv_heads = 8 + >>> head_dim = 128 + >>> max_num_pages = 128 + >>> device = 0 + >>> page_block_size = 1 + >>> causal = True + >>> # allocate 128MB workspace buffer + >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") + >>> wrapper = flashinfer.BatchPODWithPagedKVCacheWrapper( + ... workspace_buffer, "NHD" + ... ) + >>> # Prefill and decode parameters + >>> p_qo_lens = [2048] * 2 + >>> d_qo_lens = [1] * 128 + >>> p_kv_lens = [2048] * 2 + >>> d_kv_lens = [2048] * 128 + >>> # Prefill plan inputs + >>> p_seq_lens_blocks = torch.ceil( + ... torch.tensor(p_kv_lens, dtype=torch.int32) / page_block_size + ... ).int() + >>> p_q_indptr = torch.cat( + ... [torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0 + ... ).int() + >>> p_kv_indptr = torch.cat( + ... [torch.tensor([0]), torch.cumsum(p_seq_lens_blocks, 0)], dim=0 + ... ).int() + >>> kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32) + >>> last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1 + >>> # Decode plan inputs + >>> d_seq_lens_blocks = torch.ceil( + ... torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size + ... ).int() + >>> d_q_indptr = torch.cat( + ... [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 + ... ).int() + >>> d_kv_indptr = torch.cat( + ... [torch.tensor([0]), torch.cumsum(d_seq_lens_blocks, 0)], dim=0 + ... ).int() + >>> kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32) + >>> last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1 + >>> # create auxiliary data structures for batch decode attention + >>> wrapper.plan( + ... # Prefill params + ... p_q_indptr.to(device), + ... p_kv_indptr.to(device), + ... kv_indices_p.to(device), + ... last_page_len_p, + ... # Decode params + ... d_q_indptr.to(device), + ... d_kv_indptr.to(device), + ... kv_indices_d.to(device), + ... last_page_len_d, + ... # Common params + ... num_qo_heads=num_qo_heads, + ... num_kv_heads=num_kv_heads, + ... head_dim=head_dim, + ... page_size=page_block_size, + ... q_data_type=torch.bfloat16, + ... kv_data_type=torch.bfloat16, + ... ) + >>> # Prefill input tensors + >>> q_p = torch.rand(p_q_indptr[-1].item(), num_qo_heads, head_dim).to( + ... device, dtype=torch.bfloat16 + ... ) + >>> kv_p = torch.randn(p_kv_indptr[-1], 2, page_block_size, num_kv_heads, head_dim).to( + ... device, dtype=torch.bfloat16 + ... ).unbind(1) + >>> # Decode input tensors + >>> q_d = torch.rand(d_q_indptr[-1].item(), num_qo_heads, head_dim).to( + ... device, dtype=torch.bfloat16 + ... ) + >>> kv_d = torch.randn(d_kv_indptr[-1], 2, page_block_size, num_kv_heads, head_dim).to( + ... device, dtype=torch.bfloat16 + ... 
).unbind(1) + >>> for i in range(num_layers): + ... o_p_batch, o_d_batch = wrapper.run( + ... q_p, + ... kv_p, + ... q_d, + ... kv_d, + ... causal_p=causal, + ... ) + >>> print(o_p_batch.shape, o_d_batch.shape) + torch.Size([4096, 64, 128]) torch.Size([128, 64, 128]) + + Note + ---- + To accelerate computation, FlashInfer's POD-Attention creates some + auxiliary data structures, these data structures can be reused across multiple + batch decode attention calls (e.g. different Transformer layers). This wrapper class + manages the lifecycle of these data structures. + """ + + def __init__( + self, + float_workspace_buffer: torch.Tensor, + kv_layout: str = "NHD", + ) -> None: + r"""Constructor of :class:`BatchPODWithPagedKVCacheWrapper`. + + Parameters + ---------- + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace + buffer should be the same as the device of the input tensors. + + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + + """ + _check_kv_layout(kv_layout) + # Override options. Only tensor core version is performant. + use_tensor_cores = True + self._jit_module: SimpleNamespace = None + + self._kv_layout = kv_layout + float_workspace_buffer_p, float_workspace_buffer_d = torch.chunk( + float_workspace_buffer, 2, dim=0 + ) + self._float_workspace_buffer_p = float_workspace_buffer_p + self._float_workspace_buffer_d = float_workspace_buffer_d + self.device = float_workspace_buffer_p.device + self._int_workspace_buffer_p = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + self._int_workspace_buffer_d = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + self._pin_memory_int_workspace_buffer_p = torch.empty( + (8 * 1024 * 1024,), + dtype=torch.uint8, + pin_memory=True, + device="cpu", + ) + self._pin_memory_int_workspace_buffer_d = torch.empty( + (8 * 1024 * 1024,), + dtype=torch.uint8, + pin_memory=True, + device="cpu", + ) + + # SM aware scheduling buffer, requires SMs count + 2 entries + dev_prop = torch.cuda.get_device_properties(self.device) + self._sm_aware_sched = torch.empty( + (dev_prop.multi_processor_count + 2), dtype=torch.int, device=self.device + ) + + self._fixed_batch_size = 0 + + self._paged_kv_indptr_buf = None + self._paged_kv_indices_buf = None + self._paged_kv_last_page_len_buf = None + self._use_tensor_cores = use_tensor_cores + self._use_cuda_graph = False + + @property + def is_cuda_graph_enabled(self) -> bool: + return self._use_cuda_graph + + def plan( + self, + qo_indptr_p: torch.Tensor, + kv_indptr_p: torch.Tensor, + kv_indices_p: torch.Tensor, + last_page_len_p: torch.Tensor, + qo_indptr_d: torch.Tensor, + kv_indptr_d: torch.Tensor, + kv_indices_d: torch.Tensor, + last_page_len_d: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + pos_encoding_mode: str = "NONE", + window_left: int = -1, + q_data_type: Optional[Union[str, torch.dtype]] = "float16", + kv_data_type: Optional[Union[str, torch.dtype]] = None, + data_type: Optional[Union[str, torch.dtype]] = None, + sm_scale: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, + non_blocking: bool = True, + ) -> None: + r"""Plan POD's batch prefill and decode for given problem specification. 
+ + Parameters + ---------- + qo_indptr_p : torch.Tensor + The prefill indptr of the query/output tensor, shape: ``[batch_size + 1]``. + kv_indptr_p : torch.Tensor + The prefill indptr of the paged kv-cache, shape: ``[batch_size + 1]``. + kv_indices_p : torch.Tensor + The prefill page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]``. + last_page_len_p : torch.Tensor + The number of entries in the last page of each prefill request in the paged + kv-cache, shape: ``[batch_size]``. + qo_indptr_d : torch.Tensor + The decode indptr of the query/output tensor, shape: ``[batch_size + 1]``. + kv_indptr_d : torch.Tensor + The decode indptr of the paged kv-cache, shape: ``[batch_size + 1]``. + kv_indices_d : torch.Tensor + The decode page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]``. + last_page_len_d : torch.Tensor + The number of entries in the last page of each decode request in the paged + kv-cache, shape: ``[batch_size]``. + num_qo_heads : int + The number of query/output heads + num_kv_heads : int + The number of key/value heads + head_dim : int + The dimension of the heads + page_size : int + The page size of the paged kv cache + pos_encoding_mode : str + The position encoding applied inside attention kernels, could be + ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Defaults to ``NONE``. + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. + q_data_type : Optional[Union[str, torch.dtype]] + The data type of the query tensor, defaults torch.float16. + kv_data_type : Optional[Union[str, torch.dtype]] + The data type of the key/value tensor. If None, will be set to + ``q_data_type``. Defaults to ``None``. + data_type: Optional[Union[str, torch.dtype]] + The data type of both the query and key/value tensors. Defaults to torch.float16. + data_type is deprecated, please use q_data_type and kv_data_type instead. + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to + ``1.0 / sqrt(head_dim_qk)``. + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + non_blocking : bool + Whether to copy the input tensors to the device asynchronously, defaults to ``True``. + + Note + ---- + The :meth:`plan` method should be called before any :meth:`run` or + :meth:`run_return_lse` calls, auxiliary data structures will be created + during this call and cached for multiple run calls. + + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` + is not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. + + The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. 
+ """ + # Logits soft cap is not supported currently + logits_soft_cap = 0.0 + + # Setup prefill params + batch_size_p = len(last_page_len_p) + qo_indptr_host_p = qo_indptr_p.to("cpu") + total_num_rows_p = int(qo_indptr_host_p[-1]) + self._kv_indptr_buf_p = kv_indptr_p.to(self.device, non_blocking=non_blocking) + self._kv_indices_buf_p = kv_indices_p.to(self.device, non_blocking=non_blocking) + self._kv_last_page_len_buf_p = last_page_len_p.to( + self.device, non_blocking=non_blocking + ) + self._qo_indptr_buf_p = qo_indptr_host_p.to( + self.device, non_blocking=non_blocking + ) + kv_indptr_host_p = kv_indptr_p.to("cpu") + last_page_len_host_p = last_page_len_p.to("cpu") + kv_lens_arr_host_p = get_seq_lens( + kv_indptr_host_p, last_page_len_host_p, page_size + ) + + if data_type is not None: + if q_data_type is None: + q_data_type = data_type + if kv_data_type is None: + kv_data_type = data_type + + q_data_type = canonicalize_torch_dtype(q_data_type) + if kv_data_type is None: + kv_data_type = q_data_type + kv_data_type = canonicalize_torch_dtype(kv_data_type) + + self._cached_q_data_type = q_data_type + self._cached_kv_data_type = kv_data_type + if self._jit_module is not None: + self._cached_module = self._jit_module + else: + self._cached_module = get_batch_prefill_module( + "fa2", + q_data_type, + kv_data_type, + q_data_type, + kv_indptr_p.dtype, + head_dim, # head_dim_qk + head_dim, # head_dim_vo + PosEncodingMode[pos_encoding_mode].value, + window_left != -1, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + False, # use_fp16_qk_reduction + ) + + # Setup decode params + batch_size_d = len(last_page_len_d) + qo_indptr_host_d = qo_indptr_d.to("cpu") + total_num_rows_d = int(qo_indptr_host_d[-1]) + self._kv_indptr_buf_d = kv_indptr_d.to(self.device, non_blocking=non_blocking) + self._kv_indices_buf_d = kv_indices_d.to(self.device, non_blocking=non_blocking) + self._kv_last_page_len_buf_d = last_page_len_d.to( + self.device, non_blocking=non_blocking + ) + self._qo_indptr_buf_d = qo_indptr_host_d.to( + self.device, non_blocking=non_blocking + ) + kv_indptr_host_d = kv_indptr_d.to("cpu") + last_page_len_host_d = last_page_len_d.to("cpu") + kv_lens_arr_host_d = get_seq_lens( + kv_indptr_host_d, last_page_len_host_d, page_size + ) + + self._plan_info_d = self._cached_module.plan( + self._float_workspace_buffer_d, + self._int_workspace_buffer_d, + self._pin_memory_int_workspace_buffer_d, + qo_indptr_host_d, + kv_indptr_host_d, + kv_lens_arr_host_d, + total_num_rows_d, # total_num_rows + batch_size_d, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + head_dim, + head_dim, + False, # causal + window_left, + -1, # fixed_split_size + False, # disable_split_kv + 0, # num_colocated_ctas + ) + + num_colocated_ctas = self._plan_info_d[0] + # Splitting small prefill causes unecessary bandwidth contention + if total_num_rows_p > 1536: + num_colocated_ctas = 0 + self._plan_info_p = self._cached_module.plan( + self._float_workspace_buffer_p, + self._int_workspace_buffer_p, + self._pin_memory_int_workspace_buffer_p, + qo_indptr_host_p, + kv_indptr_host_p, + kv_lens_arr_host_p, + total_num_rows_p, # total_num_rows + batch_size_p, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + head_dim, + head_dim, + False, # causal + window_left, + -1, # fixed_split_size + False, # disable_split_kv + num_colocated_ctas, + ) + self._indptr_type = kv_indptr_p.dtype + self._pos_encoding_mode = pos_encoding_mode + self._window_left = window_left + 
self._logits_soft_cap = logits_soft_cap + self._sm_scale = sm_scale + self._rope_scale = rope_scale + self._rope_theta = rope_theta + + begin_forward = plan + + def run( + self, + # Main params (prefill and decode) + q_p: torch.Tensor, + paged_kv_cache_p: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q_d: torch.Tensor, + paged_kv_cache_d: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + # Prefill options + custom_mask_p: Optional[torch.Tensor] = None, + packed_custom_mask_p: Optional[torch.Tensor] = None, + causal_p: bool = False, + # Decode options + q_scale: Optional[float] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + # Common options + return_lse: bool = False, + use_fp16_qk_reduction: bool = False, + enable_pdl: Optional[bool] = None, + ) -> Union[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]], + ]: + r"""Compute POD-attention for a batch of requests.""" + if enable_pdl is None: + enable_pdl = device_support_pdl(q_p.device) + + # Currently unsupported + logits_soft_cap_p = None + logits_soft_cap_d = None + # Prefill setup + k_cache_p, v_cache_p = _unpack_paged_kv_cache(paged_kv_cache_p, self._kv_layout) + _check_cached_qkv_data_type( + q_p, k_cache_p, self._cached_q_data_type, self._cached_kv_data_type + ) + # Get params from plan + pos_encoding_mode_p = self._pos_encoding_mode + window_left_p = self._window_left + logits_soft_cap_p = self._logits_soft_cap + sm_scale_p = self._sm_scale + rope_scale_p = self._rope_scale + rope_theta_p = self._rope_theta + _check_pos_encoding_mode(pos_encoding_mode_p) + if logits_soft_cap_p is None: + logits_soft_cap_p = 0.0 + if sm_scale_p is None: + head_dim = q_p.shape[-1] + sm_scale_p = 1.0 / math.sqrt(head_dim) + if rope_scale_p is None: + rope_scale_p = 1.0 + if rope_theta_p is None: + rope_theta_p = 1e4 + + if custom_mask_p is not None and packed_custom_mask_p is None: + # create packed custom mask from custom mask + packed_custom_mask_p = packbits( + custom_mask_p.contiguous().view(-1), bitorder="little" + ) + + if packed_custom_mask_p is not None: + mask_mode_p = MaskMode.CUSTOM.value + else: + if causal_p: + mask_mode_p = MaskMode.CAUSAL.value + else: + mask_mode_p = MaskMode.NON_CAUSAL.value + + lse_p = None + if return_lse: + lse_p = torch.empty( + (q_p.size(0), q_p.size(1)), dtype=torch.float32, device=q_p.device + ) + out_p = torch.empty_like(q_p) + + # Decode setup + k_cache_d, v_cache_d = _unpack_paged_kv_cache(paged_kv_cache_d, self._kv_layout) + _check_cached_qkv_data_type( + q_d, k_cache_d, self._cached_q_data_type, self._cached_kv_data_type + ) + # Get params from plan + pos_encoding_mode_d = self._pos_encoding_mode + window_left_d = self._window_left + logits_soft_cap_d = self._logits_soft_cap + sm_scale_d = self._sm_scale + rope_scale_d = self._rope_scale + rope_theta_d = self._rope_theta + _check_pos_encoding_mode(pos_encoding_mode_d) + if logits_soft_cap_d is None: + logits_soft_cap_d = 0.0 + if sm_scale_d is None: + head_dim = q_d.shape[-1] + sm_scale_d = 1.0 / math.sqrt(head_dim) + if q_scale is not None: + sm_scale_d *= q_scale + if k_scale is not None: + sm_scale_d *= k_scale + if rope_scale_d is None: + rope_scale_d = 1.0 + if rope_theta_d is None: + rope_theta_d = 1e4 + + lse_d = None + if return_lse: + lse_d = torch.empty( + (q_d.size(0), q_d.size(1)), dtype=torch.float32, device=q_d.device + ) + out_d = torch.empty_like(q_d) + + module_getter = get_batch_pod_module( + # Prefill params + q_p.dtype, + 
k_cache_p.dtype, + q_p.dtype, + q_p.shape[-1], + PosEncodingMode[pos_encoding_mode_p].value, + window_left_p >= 0, # use_sliding_window + logits_soft_cap_p > 0, # use_logits_soft_cap + use_fp16_qk_reduction, + # Decode params + self._indptr_type, + PosEncodingMode[pos_encoding_mode_d].value, + window_left_d != -1, # use_sliding_window + logits_soft_cap_d > 0, # use_logits_soft_cap + ) + module_getter.run_tensor( + # Prefill params + self._float_workspace_buffer_p, + self._int_workspace_buffer_p, + self._plan_info_p, + q_p, + k_cache_p, + v_cache_p, + self._qo_indptr_buf_p, + self._kv_indptr_buf_p, + self._kv_indices_buf_p, + self._kv_last_page_len_buf_p, + out_p, + lse_p, + mask_mode_p, + TensorLayout[self._kv_layout].value, + window_left_p, + packed_custom_mask_p, # packed_custom_mask + None, # mask_indptr_buf + _get_cache_alibi_slopes_buf(q_p.shape[1], q_p.device), + logits_soft_cap_p, + sm_scale_p, + 1.0 / rope_scale_p, + 1.0 / rope_theta_p, + # Decode params + self._float_workspace_buffer_d, + self._int_workspace_buffer_d, + self._plan_info_d, + q_d, + k_cache_d, + v_cache_d, + self._qo_indptr_buf_d, + self._kv_indptr_buf_d, + self._kv_indices_buf_d, + self._kv_last_page_len_buf_d, + out_d, + lse_d, + MaskMode.NON_CAUSAL.value, + TensorLayout[self._kv_layout].value, + window_left_d, + None, # packed_custom_mask + None, # mask_indptr_buf + _get_cache_alibi_slopes_buf(q_d.shape[1], q_d.device), + logits_soft_cap_d, + sm_scale_d, + 1.0 / rope_scale_d, + 1.0 / rope_theta_d, + enable_pdl, + self._sm_aware_sched, + ) + + if v_scale is not None: + out_d *= v_scale + + return ((out_p, out_d), (lse_p, lse_d)) if return_lse else (out_p, out_d) + + def end_forward(self) -> None: + r"""Warning: this function is deprecated and has no effect.""" + pass diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 49abe60897..6b4353011f 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1902,6 +1902,7 @@ def plan( if self._backend == "fa2": args.append(fixed_split_size or -1) # fixed_split_size args.append(disable_split_kv) # disable_split_kv + args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, ) @@ -2769,6 +2770,7 @@ def plan( if self._backend == "fa2": args.append(fixed_split_size or -1) # fixed_split_size args.append(disable_split_kv) # disable_split_kv + args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, ) diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 36e26bb684..37a3d444b7 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -473,6 +473,7 @@ def plan( if self._backend == "fa2": args.append(-1) # fixed_split_size args.append(False) # disable_split_kv + args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, ) @@ -1062,6 +1063,7 @@ def _block_mask_map_to_expanded_indices( if self._backend == "fa2": args.append(-1) # fixed_split_size args.append(False) # disable_split_kv + args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, ) diff --git a/include/flashinfer/attention/batch_pod.cuh b/include/flashinfer/attention/batch_pod.cuh new file mode 100644 index 0000000000..d8e0e12985 --- /dev/null +++ b/include/flashinfer/attention/batch_pod.cuh @@ -0,0 +1,394 @@ +#ifndef FLASHINFER_BATCH_POD_CUH_ +#define FLASHINFER_BATCH_POD_CUH_ + +#include +#include +#include +#include +#include + +#include "../cp_async.cuh" +#include "../fastdiv.cuh" +#include "../frag_layout_swizzle.cuh" +#include "../layout.cuh" +#include "../math.cuh" 
+#include "../mma.cuh" +#include "../page.cuh" +#include "../permuted_smem.cuh" +#include "../pos_enc.cuh" +#include "../utils.cuh" +#include "cascade.cuh" +#include "decode.cuh" +#include "mask.cuh" +#include "prefill.cuh" +#include "variants.cuh" + +namespace flashinfer { + +namespace cg = cooperative_groups; +using cp_async::SharedMemFillMode; +using mma::MMAMode; + +enum Operation { + PREFILL = 0, + DECODE = 1, +}; + +template +__global__ __launch_bounds__(std::max( + KTraits_P::NUM_THREADS, + KTraits_D::NUM_THREADS)) void BatchPODWithKVCacheTensorKernel(const __grid_constant__ + PrefillParams prefill_params, + const __grid_constant__ + DecodeParams decode_params, + int* sm_aware_sched) { + extern __shared__ uint8_t smem[]; + // PREFILL VARS + const uint32_t padded_bsize_p = prefill_params.padded_batch_size; + const uint32_t num_kv_heads_p = prefill_params.paged_kv.num_heads; + + // DECODE VARS + const uint32_t padded_bsize_d = decode_params.padded_batch_size; + const uint32_t num_kv_heads_d = decode_params.paged_kv.num_heads; + + // THREADBLOCKS + const uint32_t prefill_blocks = padded_bsize_p * num_kv_heads_p; + const uint32_t decode_blocks = padded_bsize_d * num_kv_heads_d; + + int op; + int linear_bid; + // SM-aware CTA scheduler + if (threadIdx.x == 0) { + // TODO_AK: If num_threads dont match, use virtual sub-CTAs. + // Requires changing block-level sync in main prefill/decode kernels. + constexpr int blk_factor_p = 1; + constexpr int blk_factor_d = 1; + + // SM-aware threadblock scheduler code + // Find out which SM this threadblock is scheduled on + int num_SMs; + // WARNING: nsmid has only been tested on A100/H100, and matches SM count + // No guarantee this will work on other GPUs + asm volatile("mov.u32 %0, %nsmid;" : "=r"(num_SMs)); + asm volatile("mov.u32 %0, %smid;" : "=r"(linear_bid)); + const int prefill_slots = (prefill_blocks + blk_factor_p - 1) / blk_factor_p; + const int decode_slots = (decode_blocks + blk_factor_d - 1) / blk_factor_d; + + if (prefill_slots <= decode_slots) { + // Total tags = (decode + prefill) / min(decode, prefill) + // = 1 + decode / prefill; when prefill < decode + const int total_tags = decode_slots / prefill_slots + 1; + // For this SM, what's the next operation we want to run? + op = (atomicAdd(&sm_aware_sched[linear_bid], 1) % total_tags); + if (op > 0) { + op = 1; + } + } else { + // Total tags = (decode + prefill) / min(decode, prefill) + // = 1 + prefill / decode; when decode < prefill + const int pref_tags = prefill_slots / decode_slots; + + // For this SM, what's the next operation we want to run? + op = (atomicAdd(&sm_aware_sched[linear_bid], 1) % (pref_tags + 1)); + if (op < pref_tags) { + op = 0; + } else { + op = 1; + } + } + + // Get the next blockId for that operation + linear_bid = atomicAdd(&sm_aware_sched[num_SMs + op], 1); + // If the blockId obtained exceeds the max blockIds for that op, switch to the other op + if (op == 0 && linear_bid >= prefill_slots) { + linear_bid = atomicAdd(&sm_aware_sched[num_SMs + 1], 1); + op = !op; + } else if (op == 1 && linear_bid >= decode_slots) { + op = !op; + linear_bid = atomicAdd(&sm_aware_sched[num_SMs + 0], 1); + } + // Write the blockId and operation to shared memory + ((int*)smem)[0] = linear_bid; + ((int*)smem)[1] = op; + } + // Sync to wait for dynamic scheduler to finish + __syncthreads(); + // Fetch from shared memory the assigned blockId and operation. 
+ linear_bid = ((int*)smem)[0]; + op = ((int*)smem)[1]; + // Sync to force all threads to wait + __syncthreads(); + + if (op == PREFILL) { + auto& smem_storage = reinterpret_cast(smem); + // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); + if (linear_bid >= prefill_blocks) return; + + const uint32_t bx = linear_bid % padded_bsize_p; + const uint32_t kv_head_idx = linear_bid / padded_bsize_p; + + // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); + const uint32_t linear_tid = threadIdx.x; + // Return if threadId exceeds number of threads for this op + if (linear_tid >= 32 * KTraits_P::NUM_WARPS_Q * KTraits_P::NUM_WARPS_KV) return; + + const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_P::NUM_WARPS_Q, + (linear_tid / 32) / KTraits_P::NUM_WARPS_Q); + + BatchPrefillWithPagedKVCacheDevice(prefill_params, smem_storage, tid, bx, + kv_head_idx, num_kv_heads_p); + } else /* OP == DECODE */ { + auto& smem_storage = reinterpret_cast(smem); + // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); + if (linear_bid >= decode_blocks) return; + + const uint32_t bx = linear_bid % padded_bsize_d; + const uint32_t kv_head_idx = linear_bid / padded_bsize_d; + + // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); + const uint32_t linear_tid = threadIdx.x; + // Return if threadId exceeds number of threads for this op + if (linear_tid >= 32 * KTraits_D::NUM_WARPS_Q * KTraits_D::NUM_WARPS_KV) return; + + const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_D::NUM_WARPS_Q, + (linear_tid / 32) / KTraits_D::NUM_WARPS_Q); + + BatchPrefillWithPagedKVCacheDevice(decode_params, smem_storage, tid, bx, kv_head_idx, + num_kv_heads_d); + } +} + +template +cudaError_t BatchPODWithKVCacheTensorDispatched(PrefillParams prefill_params, + typename PrefillParams::DTypeO* tmp_v_p, + float* tmp_s_p, DecodeParams decode_params, + typename DecodeParams::DTypeO* tmp_v_d, + float* tmp_s_d, bool enable_pdl, + cudaStream_t stream, int* sm_aware_sched) { + static_assert(std::is_same::value); + static_assert( + std::is_same::value); + static_assert(std::is_same::value); + // Ensure heads match + assert(prefill_params.paged_kv.num_heads == decode_params.paged_kv.num_heads); + assert(prefill_params.num_qo_heads == decode_params.num_qo_heads); + // Common variables for both prefill and decode + const uint32_t num_qo_heads = prefill_params.num_qo_heads; + const uint32_t num_kv_heads = prefill_params.paged_kv.num_heads; + + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + int max_smem_per_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm, + cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); + + // Prefill variable setup + using DTypeQ_P = typename PrefillParams::DTypeQ; + using DTypeKV_P = typename PrefillParams::DTypeKV; + using DTypeO_P = typename PrefillParams::DTypeO; + const uint32_t padded_batch_size_p = prefill_params.padded_batch_size; + constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); + constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); + constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); + + using DTypeQKAccum_P = + typename std::conditional, half, + float>::type; + + const uint32_t group_size = num_qo_heads / num_kv_heads; + const uint_fastdiv group_size_fastdiv(group_size); + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + + // Decode vars setup + using DTypeQ_D = typename DecodeParams::DTypeQ; + using DTypeKV_D = typename DecodeParams::DTypeKV; + using 
DTypeO_D = typename DecodeParams::DTypeO; + const uint32_t padded_batch_size_d = decode_params.padded_batch_size; + constexpr uint32_t NUM_MMA_Q_D = get_num_mma_q(CTA_TILE_Q_D); + constexpr uint32_t NUM_WARPS_Q_D = get_num_warps_q(CTA_TILE_Q_D); + constexpr uint32_t NUM_WARPS_KV_D = get_num_warps_kv(CTA_TILE_Q_D); + + // constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + // constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + using DTypeQKAccum_D = + typename std::conditional, half, + float>::type; + + // we expect each sm execute two threadblocks + // TODO(Zihao): fix the following computation + const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_D) * 16) ? 2 : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + // Prefill params + constexpr uint32_t max_num_mma_kv_reg_p = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 2 + : (8 / NUM_MMA_Q_P); + const uint32_t max_num_mma_kv_smem_p = + (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) - + NUM_MMA_Q_P * NUM_WARPS_Q_P) / + (2 * NUM_WARPS_KV_P); + + // Decode params + constexpr uint32_t max_num_mma_kv_reg_d = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q_D == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 2 + : (8 / NUM_MMA_Q_D); + // TODO(Zihao): fix the following computation + const uint32_t max_num_mma_kv_smem_d = + (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ_D)) - + NUM_MMA_Q_D * NUM_WARPS_Q_D) / + (2 * NUM_WARPS_KV_D); + + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, { + using KTraits_P = KernelTraits; + + if constexpr (KTraits_P::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P + << " NUM_WARPS_KV=" << NUM_WARPS_KV_P + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + // Decode stuff + // TODO: Is there a way to avoid this nested dispatch? 
+ DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, { + using KTraits_D = + KernelTraits; + if constexpr (KTraits_D::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D + << " NUM_WARPS_KV=" << NUM_WARPS_KV_D + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + // End decode stuff + constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE; + size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); + size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); + + auto kernel = + BatchPODWithKVCacheTensorKernel; + + // Setup new prefill params if (not) split + auto o_p = prefill_params.o; + auto lse_p = prefill_params.lse; + if (tmp_v_p == nullptr) { + // do not partition kv + prefill_params.partition_kv = false; + } else { + prefill_params.partition_kv = true; + prefill_params.o = tmp_v_p; + prefill_params.lse = tmp_s_p; + } + + // Setup new decode params if (not) split + auto o_d = decode_params.o; + auto lse_d = decode_params.lse; + if (tmp_v_d == nullptr) { + // do not partition kv + decode_params.partition_kv = false; + } else { + decode_params.partition_kv = true; + decode_params.o = tmp_v_d; + decode_params.lse = tmp_s_d; + } + int nblks_p(padded_batch_size_p * 1 * num_kv_heads); + int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); + + int nblks_d(padded_batch_size_d * 1 * num_kv_heads); + int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); + + // ******* Select final combined sizes here ******* / + size_t smem_size = max(smem_size_p, smem_size_d); + int nblks = nblks_p + nblks_d; + int nthrs = max(nthrs_p, nthrs_d); + // ************************************************ / + + int num_sm = 0; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL( + cudaMemsetAsync(sm_aware_sched, 0, sizeof(int) * (num_sm + 2), stream)); + + // Setup kernel arguments + void* args[] = {(void*)&prefill_params, (void*)&decode_params, (void*)&sm_aware_sched}; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Launch kernel + if (enable_pdl) { + cudaLaunchAttribute attribute[1]; + cudaLaunchConfig_t config; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attribute; + config.numAttrs = 1; + config.gridDim = nblks; + config.blockDim = nthrs; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + FLASHINFER_CUDA_CALL( + cudaLaunchKernelEx(&config, kernel, prefill_params, decode_params, sm_aware_sched)); + } else { + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } + + // Post-kernel stuff for split-kv prefill + if (tmp_v_p != nullptr) { + if constexpr (PrefillAttentionVariant::use_softmax) { + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v_p, tmp_s_p, prefill_params.merge_indptr, o_p, lse_p, + prefill_params.max_total_num_rows, prefill_params.total_num_rows, num_qo_heads, + HEAD_DIM_VO, enable_pdl, stream)); + } else { + 
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( + tmp_v_p, prefill_params.merge_indptr, o_p, prefill_params.max_total_num_rows, + prefill_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream)); + } + } + // Post-kernel stuff for split-kv decode + if (tmp_v_d != nullptr) { + if constexpr (DecodeAttentionVariant::use_softmax) { + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v_d, tmp_s_d, decode_params.merge_indptr, o_d, lse_d, + decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, + HEAD_DIM_VO, enable_pdl, stream)); + } else { + FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( + tmp_v_d, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows, + decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream)); + } + } + } + }); + } + }); + return cudaSuccess; +} + +} // namespace flashinfer + +#endif // FLASHINFER_BATCH_POD_CUH_ diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 4f888e716b..286023e204 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -443,7 +443,6 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in padded_batch_size = (enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : new_batch_size; plan_info.padded_batch_size = padded_batch_size; - auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] = DecodeSplitKVIndptr(indptr_h, batch_size, kv_chunk_size_in_pages); @@ -700,6 +699,8 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o, int32_t window_left, int32_t fixed_split_size, bool disable_split_kv, + int64_t num_colocated_ctas, // for POD attention, limit prefill + // splits by #colocated decode CTAs cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; @@ -714,7 +715,8 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); int num_blocks_per_sm = 2; - int max_grid_size = num_blocks_per_sm * num_sm; + int64_t available_ctas = static_cast(num_blocks_per_sm) * num_sm - num_colocated_ctas; + int max_grid_size = static_cast(std::max(0, available_ctas)); uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; // step 2: determine kv_chunk_size From 37434ed9f2b5eb4a5ef4869f769ab05b1cac6f8d Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Fri, 14 Nov 2025 11:43:00 -0800 Subject: [PATCH 058/130] feat: patch sm103 for 3xfp4 moe generation (#2082) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Patch sm103 for 3xfp4 moe generation ## ๐Ÿ” Related Issues Following up of #2020 #1925 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). 
## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ``` $ ls csrc/nv_internal/tensorrt_llm/cutlass_instantiations/103/gemm_grouped 100 103 80 $ pytest tests/moe/test_trtllm_cutlass_fused_moe.py 22 passed, 3 skipped, 1 warning in 771.89s (0:12:51) ``` ## Summary by CodeRabbit * **New Features** * Added support for Blackwell (SM103) GPU architecture in MOE (Mixture of Experts) operations with specialized CUTLASS-optimized modules. --- flashinfer/aot.py | 2 ++ flashinfer/fused_moe/__init__.py | 2 ++ flashinfer/fused_moe/core.py | 5 ++++- flashinfer/jit/fused_moe.py | 18 ++++++++++++++++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 5801cc933b..609e1bcbcf 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -43,6 +43,7 @@ from .jit.fp8_quantization import gen_mxfp8_quantization_sm100_module from .jit.fused_moe import ( gen_cutlass_fused_moe_sm120_module, + gen_cutlass_fused_moe_sm103_module, gen_cutlass_fused_moe_sm100_module, gen_cutlass_fused_moe_sm90_module, gen_trtllm_gen_fused_moe_sm100_module, @@ -495,6 +496,7 @@ def gen_all_modules( jit_specs.append(gen_tgv_gemm_sm10x_module(torch.float16, use_sm_100f=True)) if has_sm103: jit_specs.append(gen_fp4_quantization_sm103_module()) + jit_specs.append(gen_cutlass_fused_moe_sm103_module()) if has_sm110: jit_specs.append(gen_fp4_quantization_sm110_module()) if has_sm120: diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 8121c99c0a..84e3ade9c7 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -21,6 +21,7 @@ convert_to_block_layout, cutlass_fused_moe, gen_cutlass_fused_moe_sm120_module, + gen_cutlass_fused_moe_sm103_module, gen_cutlass_fused_moe_sm100_module, gen_cutlass_fused_moe_sm90_module, gen_trtllm_gen_fused_moe_sm100_module, @@ -39,6 +40,7 @@ "convert_to_block_layout", "cutlass_fused_moe", "gen_cutlass_fused_moe_sm120_module", + "gen_cutlass_fused_moe_sm103_module", "gen_cutlass_fused_moe_sm100_module", "gen_cutlass_fused_moe_sm90_module", "gen_trtllm_gen_fused_moe_sm100_module", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index b4444aa431..3c5e7a09c5 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -34,6 +34,7 @@ ) from ..jit.fused_moe import ( gen_cutlass_fused_moe_sm120_module, + gen_cutlass_fused_moe_sm103_module, gen_cutlass_fused_moe_sm100_module, gen_cutlass_fused_moe_sm90_module, gen_cutlass_fused_moe_sm89_module, @@ -315,7 +316,9 @@ def convert_to_block_layout(input_tensor: torch.Tensor, blockK: int) -> torch.Te def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = False): if backend in ("120", "121"): module = gen_cutlass_fused_moe_sm120_module(use_fast_build).build_and_load() - elif backend in ("100", "103", "110"): + elif backend == "103": + module = gen_cutlass_fused_moe_sm103_module(use_fast_build).build_and_load() + elif backend in ("100", "110"): module = gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load() elif backend == "90": module = gen_cutlass_fused_moe_sm90_module(use_fast_build).build_and_load() diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 78c19e98ac..152d92f161 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -47,6 +47,24 @@ def gen_cutlass_fused_moe_sm120_module(use_fast_build: bool = False) -> JitSpec: return 
gen_cutlass_fused_moe_module(nvcc_flags, "120", use_fast_build) +def gen_cutlass_fused_moe_sm103_module(use_fast_build: bool = False) -> JitSpec: + nvcc_flags = [ + "-DCOMPILE_BLACKWELL_TMA_GEMMS", + "-DCOMPILE_BLACKWELL_TMA_GROUPED_GEMMS", + "-DENABLE_BF16", + "-DENABLE_FP8", + "-DENABLE_FP4", + "-DUSING_OSS_CUTLASS_MOE_GEMM", + "-DCOMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS", + ] + + nvcc_flags += current_compilation_context.get_nvcc_flags_list( + supported_major_versions=[10] + ) + + return gen_cutlass_fused_moe_module(nvcc_flags, "103", use_fast_build) + + def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec: nvcc_flags = [ "-DCOMPILE_BLACKWELL_TMA_GEMMS", From ba8f3ed98ab27f0821124ba2278824cdba83478a Mon Sep 17 00:00:00 2001 From: Maximilien Breughe <50598321+nvmbreughe@users.noreply.github.com> Date: Fri, 14 Nov 2025 14:12:41 -0600 Subject: [PATCH 059/130] MNNVL All Reduce for large number of tokens (#2074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This PR does two things: * Add a check for the number of tokens and raise an exception if the max token size was exceeded * Adds an optional parameter to allow users to dial in an arbitrary workspace ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Added an optional configurable workspace buffer size for all-reduce operations with a sensible default to preserve backwards compatibility. * Runtime input validation now enforces 2D inputs and token-count limits, with clearer error messages guiding corrective actions. * **Tests** * Expanded test coverage for workspace behavior: default sizing, explicit sizing, and negative tests for insufficient workspace. * Tests now allow supplying an explicit workspace size to validate allocation and reuse scenarios. --- flashinfer/comm/trtllm_mnnvl_ar.py | 18 ++- tests/comm/test_trtllm_mnnvl_allreduce.py | 134 ++++++++++++++++++---- 2 files changed, 128 insertions(+), 24 deletions(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index d8d975db73..76aedee260 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -122,7 +122,7 @@ def trtllm_mnnvl_rmsnorm( def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype + mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. @@ -138,6 +138,7 @@ def get_allreduce_mnnvl_workspace( Args: mapping: Tensor parallel mapping configuration containing rank info dtype: Data type of the tensors being reduced + buffer_size_in_bytes: Optional buffer size. 
Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens Returns: Tuple containing: @@ -152,7 +153,9 @@ def get_allreduce_mnnvl_workspace( # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 # max_num_elements must be a multiple of 286720 lcm_hidden_dim = 286720 - TARGET_WORKSPACE_SIZE_BYTES = 12_000_000 + TARGET_WORKSPACE_SIZE_BYTES = ( + buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 + ) buffer_size_in_bytes = math.ceil( TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) ) * (lcm_hidden_dim * stride) @@ -223,6 +226,17 @@ def trtllm_mnnvl_all_reduce( [Optional] out: Output tensor to store the result (required if wait_for_results is True) """ + + if len(inp.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." + ) + + if inp.shape[0] > buffer_M: + raise ValueError( + f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." + ) + module = get_trtllm_mnnvl_comm_module() module.trtllm_mnnvl_all_reduce( inp, diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 79830065b6..abb3795019 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -147,25 +147,27 @@ def func( ) -"""Main test function that runs on each MPI rank""" +"""Helper function to run the core MNNVL AllReduce test logic""" -@pytest.mark.parametrize( - "seq_lens", - [ - [1], - [4], - [15], - [27, 11, 24], - [127], - ], -) # Test with different sequence length lists -@pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -def test_mnnvl_allreduce_full( - monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int +def run_mnnvl_ar_full( + monkeypatch, + seq_lens: list[int], + fusion: bool, + dtype: torch.dtype, + hidden_size: int, + explicit_workspace_bytes: int | None = None, ): + """Core test logic for MNNVL AllReduce operations. + + Args: + monkeypatch: pytest monkeypatch fixture + seq_lens: List of sequence lengths to test + fusion: Whether to test fused allreduce+rmsnorm or just allreduce + dtype: Data type for tensors + hidden_size: Hidden dimension size + explicit_workspace_bytes: If provided, use this workspace size instead of default + """ monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. 
# Get MPI info @@ -211,7 +213,9 @@ def test_mnnvl_allreduce_full( # This workspace is sized for the maximum expected sequence length and can be reused within each list # Each parameterized list gets its own fresh workspace allocation mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( - trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype) + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( + mapping, dtype, buffer_size_in_bytes=explicit_workspace_bytes + ) ) multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() @@ -291,18 +295,21 @@ def test_mnnvl_allreduce_full( rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) - # Gather failure status from all ranks + + # Gather failure status from all ranks for logging all_failures = MPI.COMM_WORLD.allgather(rank_failed) - # If any rank failed, fail the test if any(all_failures): failed_ranks = [i for i, failed in enumerate(all_failures) if failed] if rank == 0: print(f"Test failed on ranks: {failed_ranks}") - # Fail the test on all ranks - pytest.fail(f"Test failed on ranks {failed_ranks}") - trtllm_mnnvl_ar.mpi_barrier() + # Cleanup before re-raising + if "mcast_buffer_mnnvl" in locals(): + del mcast_buffer_mnnvl + + # Re-raise the original exception so it can be caught by pytest.raises in negative tests + raise finally: # Ensure cleanup happens for this list's workspace @@ -311,3 +318,86 @@ def test_mnnvl_allreduce_full( # Final synchronization and check for failures across all ranks trtllm_mnnvl_ar.mpi_barrier() + + +"""Test with default workspace size""" + + +@pytest.mark.parametrize( + "seq_lens", + [ + [1], + [4], + [15], + [27, 11, 24], + [127], + ], +) +@pytest.mark.parametrize("fusion", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +def test_mnnvl_allreduce_default_workspace( + monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int +): + """Test MNNVL AllReduce with default workspace size.""" + run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) + + +"""Test with explicit workspace size""" + + +@pytest.mark.parametrize( + "seq_lens", + [ + [1, 4, 180], + ], +) +@pytest.mark.parametrize("fusion", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +def test_mnnvl_allreduce_explicit_workspace( + monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int +): + """Test MNNVL AllReduce with explicitly calculated workspace size.""" + # Calculate workspace to fit the maximum sequence length + # buffer shape: [3, 2, buffer_tokens, hidden_dim] + explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) + run_mnnvl_ar_full( + monkeypatch, + seq_lens, + fusion, + dtype, + hidden_size, + explicit_workspace_bytes=explicit_workspace_bytes, + ) + + +"""Negative test: workspace too small""" + + +@pytest.mark.parametrize("fusion", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [2048, 4096]) +def test_mnnvl_allreduce_workspace_too_small( + monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int +): + """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" + # Use a large sequence length that won't fit in a small workspace + seq_len = 
180 + + # Create a workspace that's too small (only enough for 10 tokens) + small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 + + # Expect a ValueError with a message about buffer_M being too small + with pytest.raises((ValueError, RuntimeError)) as exc_info: + run_mnnvl_ar_full( + monkeypatch, + [seq_len], + fusion, + dtype, + hidden_size, + explicit_workspace_bytes=small_workspace_bytes, + ) + + # Verify the error message contains the expected text + assert "greater than the buffer_M" in str(exc_info.value) From cce4952fdd41b353325e11d99e1fc0b0737961ff Mon Sep 17 00:00:00 2001 From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com> Date: Fri, 14 Nov 2025 22:43:40 +0100 Subject: [PATCH 060/130] perf: TRT-LLM Gen finalize kernel optimization (#2092) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description - Small optimization for TRT-LLM Gen MoE finalize kernel TopK=8, NumExperts=128, HiddenSize=4096 | BS | Baseline, us | Optimized, us | Speed-up | | ------------- | ------------- | ------------- | ------------- | | 256 | 11 | 6 | 1.83 | | 512 | 12 | 7 | 1.71 | | 1024 | 16 | 15 | 1.06 | | 4096 | 55 | 49 | 1.12 | | 8192 | 107 | 95 | 1.13 | | 16384 | 205 | 183 | 1.12 | ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Enabled vectorized, Top-K unrolled finalize path for MOE (Mixture of Experts) kernel operations with improved performance. * Added support for multiple data types (bfloat16, float, half) with enhanced type specialization and packing. * Introduced runtime validation for TopK configurations (โ‰ค 64) to ensure optimal vectorized execution. --- csrc/trtllm_fused_moe_dev_kernel.cu | 191 ++++++++++++++++-- .../flashinfer/trtllm/fused_moe/DevKernel.h | 54 +++-- 2 files changed, 202 insertions(+), 43 deletions(-) diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu index 9a51384090..7a58042041 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/trtllm_fused_moe_dev_kernel.cu @@ -672,11 +672,128 @@ __device__ float4 vectorizedLoadPtx(float4 const* ptr) { // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip // connection. 
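+// For each output token t and hidden index d, the reduction below computes
+//   out[t][d] = sum_{k < topK} scale[t][k] * permutedInput[expandedIdxToPermutedIdx[t * topK + k]][d]
+// where scale[t][k] is read from expertWeightsPtr (treated as 1.0 when no expert weights are
+// provided) and entries whose permuted index is -1 are skipped.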
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int MaxTopK = 64; + +typedef struct __CUDA_ALIGN__(4) { + cutlass::bfloat16_t array[2]; +} bfloat16_2; + +typedef struct __CUDA_ALIGN__(8) { + cutlass::bfloat16_t array[4]; +} bfloat16_4; + +typedef struct __CUDA_ALIGN__(8) { + half array[4]; +} half_4; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ScaleTraitsStruct; + +template <> +struct ScaleTraitsStruct<1, cutlass::bfloat16_t> { + using PackedType = cutlass::bfloat16_t; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, cutlass::bfloat16_t> { + using PackedType = bfloat16_2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, cutlass::bfloat16_t> { + using PackedType = bfloat16_4; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<1, float> { + using PackedType = float; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, float> { + using PackedType = float2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, float> { + using PackedType = float4; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<1, half> { + using PackedType = half; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, half> { + using PackedType = half2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, half> { + using PackedType = half_4; + using ArrayType = cutlass::Array; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FinalizeTraits; + +template +struct FinalizeTraits<1, TypeExpW_> { + using IdxPackedType = int; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<1, TypeExpW_>; + using ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +template +struct FinalizeTraits<2, TypeExpW_> { + using IdxPackedType = int2; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<2, TypeExpW_>; + using ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +template +struct FinalizeTraits<4, TypeExpW_> { + using IdxPackedType = int4; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<4, TypeExpW_>; + using ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void finalizeKernelVecLoad(KernelParams params) { using Type = typename KernelParams::Type; using TypeExpW = typename KernelParams::TypeExpW; + int constexpr TopKUnrollFactor = KernelParams::TopKUnrollFactor; + + static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4, + "TopKUnrollFactor must be 1, 2, or 4"); + using FinalizeTraits = FinalizeTraits; + using IdxPackedType = typename FinalizeTraits::IdxPackedType; + using IdxArrayType = typename FinalizeTraits::IdxArrayType; + using ScalePackedType = typename FinalizeTraits::ScalePackedType; + using ScaleArrayType = typename FinalizeTraits::ScaleArrayType; int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits::value; int 
const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits::value; @@ -694,6 +811,23 @@ __global__ void finalizeKernelVecLoad(KernelParams params) { int64_t const stride = FINALIZE_THREADS_PER_BLOCK; int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD; int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD; + bool const useScale = params.expertWeightsPtr != nullptr; + + __shared__ ScalePackedType scaleArrSmem[MaxTopK / TopKUnrollFactor]; + __shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor]; + + for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor; + kChunkIdx += blockDim.x) { + int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor; + auto permutedIdxPacked = reinterpret_cast( + params.expandedIdxToPermutedIdx)[expandedIdx / TopKUnrollFactor]; + auto scalePacked = useScale ? reinterpret_cast( + params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor] + : ScalePackedType{TypeExpW(1.f)}; + + scaleArrSmem[kChunkIdx] = scalePacked; + permutedIdxArrSmem[kChunkIdx] = permutedIdxPacked; + } auto const offset = tokenIdx * params.hiddenDim; Type* outputPtr = params.outPtr + offset; @@ -706,31 +840,42 @@ __global__ void finalizeKernelVecLoad(KernelParams params) { cudaGridDependencySynchronize(); } #endif + __syncthreads(); for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) { ComputeElem threadOutput; threadOutput.fill(0); - for (int k = 0; k < params.topK; ++k) { - int const expandedIdx = tokenIdx * params.topK + k; - int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx]; - if (permutedIdx == -1) { - continue; - } - - float const scale = (params.expertWeightsPtr != nullptr) - ? static_cast(params.expertWeightsPtr[expandedIdx]) - : 1.f; + for (int kChunkIdx = 0; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx++) { + auto permutedIdxArr = *reinterpret_cast(&permutedIdxArrSmem[kChunkIdx]); + InputElem inputElemArr[TopKUnrollFactor]; +#pragma unroll + for (int ki = 0; ki < TopKUnrollFactor; ++ki) { + auto const permutedIdx = permutedIdxArr[ki]; + if (permutedIdx == -1) { + continue; + } - auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol; + auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol; - float4 input = - vectorizedLoadPtx(reinterpret_cast(&inputPermutedPtr[elemIndex])); - InputElem inputPermutedElem = *reinterpret_cast(&input); - ComputeElem expertResult = arrayConvert(inputPermutedElem); + float4 input = + vectorizedLoadPtx(reinterpret_cast(&inputPermutedPtr[elemIndex])); + inputElemArr[ki] = *reinterpret_cast(&input); + } + auto scaleArr = *reinterpret_cast(&scaleArrSmem[kChunkIdx]); + auto const scaleFloatArr = + arrayConvert>(scaleArr); - threadOutput = threadOutput + scale * expertResult; +#pragma unroll + for (int ki = 0; ki < TopKUnrollFactor; ++ki) { + auto const permutedIdx = permutedIdxArr[ki]; + if (permutedIdx == -1) { + continue; + } + auto scale = useScale ? 
scaleFloatArr[ki] : 1.0f; + ComputeElem expertResult = arrayConvert(inputElemArr[ki]); + threadOutput = threadOutput + scale * expertResult; + } } - OutputElem outputElem = arrayConvert(threadOutput); outElemPtr[elemIndex] = outputElem; } @@ -813,7 +958,7 @@ void run(Data const& data, void* stream) { int const numBlocksY = std::min(8192, data.numTokens); dim3 numBlocks(numBlocksX, numBlocksY); - LAUNCH_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream); + LAUNCH_TOPK_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream); } else { int const numThreads = 256; int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads; @@ -827,10 +972,14 @@ void run(Data const& data, void* stream) { // ensure that when the number of waves is greater than 1, we choose to use the kernel with // vectorized loading. dim3 numBlocks(numBlocksX, numBlocksY); - LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream); + LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream); } else { - LAUNCH_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens, - /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream); + FLASHINFER_CHECK( + data.topK <= MaxTopK, + "Finalize kernel with vectorized loading is not supported for this TopK value: %d", + data.topK); + LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens, + /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream); } } } diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 0ee9ba6fe9..23abb87a7b 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -116,27 +116,36 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported dtypeElt"); \ } -#define LAUNCH_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, \ - smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, \ - smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported pair"); \ +#define LAUNCH_EXPW(data, kernel, topK, numBlocks, numThreads, smemSize, stream) \ + if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float, topK), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } 
else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float, topK), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float, topK), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t, topK), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t, topK), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, topK), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported pair"); \ + } + +#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.topK % 4 == 0) { \ + LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream); \ + } else if (data.topK % 2 == 0) { \ + LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream); \ + } else { \ + LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream); \ } #define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ @@ -453,10 +462,11 @@ struct Data { int32_t const* totalNumPaddedTokens; }; -template +template struct KernelParams { using Type = Type_; using TypeExpW = TypeExpW_; + static constexpr int TopKUnrollFactor = TopKUnrollFactor_; static constexpr bool UsePdl = UsePdl_; Type const* inPtr; From 4ddf71defe7260fd4a677fdf8ee3ecd48784a0f4 Mon Sep 17 00:00:00 2001 From: eigen <52445717+yyihuang@users.noreply.github.com> Date: Sun, 16 Nov 2025 01:01:17 -0500 Subject: [PATCH 061/130] refactor: update dpsk fused_moe test [1] (#2088) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Refactor fused_moe test. Split test on model+precision. Part [1]: - test deepseek (kimi, lite) fp8 block-scaled fused moe - default TP8 - PDL enabled - MajorK weight layout - higher tolerance and matching percentage Next Part [2]: - add BlockMajorK weight layout Next Part [x]: - Per Tensor FP8 MoE, FP4MoE later: - refactor llama4, topk?, renormalize? routing tests ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). 
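The "higher tolerance and matching percentage" mentioned in the description corresponds to a hit-ratio acceptance criterion rather than a strict `allclose`. A minimal sketch of the check used by the new test (the defaults `atol=1e-1`, `rtol=2e-1`, and an 85% match requirement mirror the test parameters; the helper name is illustrative):

```python
import torch

def hit_ratio_ok(ref: torch.Tensor, out: torch.Tensor,
                 atol: float = 1e-1, rtol: float = 2e-1,
                 percent: float = 0.85) -> bool:
    # Elementwise tolerance check; accept if enough elements fall within tolerance.
    ok = (ref.float() - out.float()).abs() <= atol + rtol * out.float().abs()
    return ok.float().mean().item() >= percent
```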
## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Added a comprehensive FP8 block-scale fused Mixture-of-Experts test validating end-to-end correctness across many routing, expert and precision configurations. Includes randomized inputs, per-token/per-expert workflows, extensive parameterizations, diagnostic statistics, autotune-path checks, and a minimal sanity run. --- tests/moe/test_dpsk_fused_moe_fp8.py | 570 +++++++++++++++++++++++++++ 1 file changed, 570 insertions(+) create mode 100644 tests/moe/test_dpsk_fused_moe_fp8.py diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py new file mode 100644 index 0000000000..3ac4055128 --- /dev/null +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -0,0 +1,570 @@ +import pytest +import torch +from flashinfer.fused_moe import trtllm_fp8_block_scale_moe, WeightLayout +from flashinfer.autotuner import autotune + + +def run( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + hidden_states: torch.Tensor, + hidden_states_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + local_expert_offset: int, + routed_scaling_factor: float, + hidden_size: int, + intermediate_size: int, + num_experts_global: int, + num_local_experts: int, + top_k: int, + n_group: int, + topk_group: int, +): + """ + - FP8 block-scale dequantization: float โ‰ˆ fp8 * scale + - DeepSeek-V3 no-aux routing: + s = sigmoid(logits) + s_with_bias = s + bias + group by n_group=8; per group take top-2 sum โ†’ pick topk_group=4 groups + on the kept groups, take global top_k=8 experts + combine with weights derived from s (without bias), normalized and + scaled by routed_scaling_factor + - Local computation: + only experts in [local_expert_offset, local_expert_offset + E_local) are + computed on this rank (GEMM1 โ†’ SwiGLU โ†’ GEMM2), then per-token weighted + accumulation. 
+ """ + + # Fixed DeepSeek-V3/R1 geometry + H = hidden_size # deepseek v3: 7168 + I = intermediate_size # deepseek v3: 2048 + E_local = gemm1_weights.shape[0] + + BLOCK = 128 + E_global = routing_logits.shape[1] + T = routing_logits.shape[0] + + assert E_global == num_experts_global, "num_experts_global shape mismatch" + assert E_local == num_local_experts, "num_local_experts shape mismatch" + + # Routing constants + TOP_K = top_k # deepseek v3: 8 + N_GROUP = n_group # deepseek v3: 8 + TOPK_GROUP = topk_group # deepseek v3: 4 + + # Block counts + num_hidden_blocks = H // BLOCK # 56 + num_intermediate_blocks = I // BLOCK # 16 + num_gemm1_out_blocks = (2 * I) // BLOCK # 32 + + # Shape checks + assert hidden_states.shape == (T, H) + assert hidden_states_scale.shape == (num_hidden_blocks, T) + assert gemm1_weights.shape == (E_local, 2 * I, H) + assert gemm1_weights_scale.shape == ( + E_local, + num_gemm1_out_blocks, + num_hidden_blocks, + ) + assert gemm2_weights.shape == (E_local, H, I) + assert gemm2_weights_scale.shape == ( + E_local, + num_hidden_blocks, + num_intermediate_blocks, + ) + assert routing_bias.shape[-1] == E_global + + device = hidden_states.device + + # 1) FP8 block-scale dequantization + # hidden_states: [T, H], scale: [H/128, T] (transposed layout) + A_fp32 = hidden_states.to(torch.float32) + A_scale = hidden_states_scale.to(torch.float32) # [H/128, T] + A_scale_TH = A_scale.permute(1, 0).contiguous() # [T, H/128] + A_scale_expanded = ( + A_scale_TH.unsqueeze(-1) + .repeat(1, 1, BLOCK) # [T, H/128, 128] + .reshape(T, H) # [T, H] + .contiguous() + ) + A = A_fp32 * A_scale_expanded # [T, H] float32 + + # W13: [E_local, 2I, H], scale: [E_local, (2I)/128, H/128] + W13_fp32 = gemm1_weights.to(torch.float32) + S13 = gemm1_weights_scale.to(torch.float32) + S13_expanded = torch.repeat_interleave(S13, BLOCK, dim=1) # [E, 2I, H/128] + S13_expanded = torch.repeat_interleave(S13_expanded, BLOCK, dim=2) # [E, 2I, H] + W13 = W13_fp32 * S13_expanded # [E, 2I, H] float32 + + # W2: [E_local, H, I], scale: [E_local, H/128, I/128] + W2_fp32 = gemm2_weights.to(torch.float32) + S2 = gemm2_weights_scale.to(torch.float32) + S2_expanded = torch.repeat_interleave(S2, BLOCK, dim=1) # [E, H, I/128] + S2_expanded = torch.repeat_interleave(S2_expanded, BLOCK, dim=2) # [E, H, I] + W2 = W2_fp32 * S2_expanded # [E, H, I] float32 + + # 2) No-aux routing + logits = routing_logits.to(torch.float32) # [T, E_global] + bias = routing_bias.to(torch.float32).reshape(-1) # [E_global] + + # Sigmoid + s = 1.0 / (1.0 + torch.exp(-logits)) # [T, E] + s_with_bias = s + bias # [T, E] (broadcast) + + # Grouping + group_size = E_global // N_GROUP # 32 + s_wb_grouped = s_with_bias.view(T, N_GROUP, group_size) # [T, 8, 32] + + # Group scores = sum of top-2 values within each group + top2_vals, _ = torch.topk( + s_wb_grouped, k=2, dim=2, largest=True, sorted=False + ) # [T, 8, 2] + group_scores = top2_vals.sum(dim=2) # [T, 8] + + # Select topk_group groups โ†’ group mask + _, group_idx = torch.topk( + group_scores, k=TOPK_GROUP, dim=1, largest=True, sorted=False + ) # [T, 4] + group_mask = torch.zeros_like(group_scores) # [T, 8] + group_mask.scatter_(1, group_idx, 1.0) + score_mask = ( + group_mask.unsqueeze(2).expand(T, N_GROUP, group_size).reshape(T, E_global) + ) # [T, E] + + # Global top-k (within kept groups), based on s_with_bias + neg_inf = torch.finfo(torch.float32).min + scores_pruned = s_with_bias.masked_fill(score_mask == 0, neg_inf) # [T, E] + _, topk_idx = torch.topk( + scores_pruned, k=TOP_K, dim=1, 
largest=True, sorted=False + ) # [T, 8] + + # Combination weights: use s (without bias) for normalization + M = torch.zeros_like(s) # [T, E] + M.scatter_(1, topk_idx, 1.0) # 0/1 mask + weights = s * M # [T, E] + weights_sum = weights.sum(dim=1, keepdim=True) + 1e-20 + weights = (weights / weights_sum) * routed_scaling_factor # [T, E] + + # 3) Local expert compute and accumulation + output = torch.zeros((T, H), dtype=torch.float32, device=device) + + local_start = int(local_expert_offset) + + # For each local expert: find selected tokens, run GEMM1โ†’SwiGLUโ†’GEMM2, accumulate by weights + for le in range(E_local): + ge = local_start + le + if ge < 0 or ge >= E_global: + continue + + # Tokens that selected this global expert ge in their top-k + sel_mask_per_token = (topk_idx == ge).any(dim=1) # [T] bool + if not sel_mask_per_token.any(): + continue + + token_idx = torch.nonzero(sel_mask_per_token, as_tuple=False).squeeze(1) # [Tk] + + # Gather inputs and weights for this expert + A_e = A.index_select(0, token_idx) # [Tk, H] + W13_e = W13[le] # [2I, H] + W2_e = W2[le] # [H, I] + + # GEMM1: [Tk, H] @ [H, 2I] = [Tk, 2I] + G1 = A_e.matmul(W13_e.t()) # [Tk, 2I] + + # SwiGLU: split and apply silu(x) = x / (1 + exp(-x)) + X1 = G1[:, :I] # [Tk, I] + X2 = G1[:, I:] # [Tk, I] + silu_X2 = X2 / (1.0 + torch.exp(-X2)) # [Tk, I] + C = silu_X2 * X1 # [Tk, I] + + # GEMM2: [Tk, I] @ [I, H] = [Tk, H] + O = C.matmul(W2_e.t()) # [Tk, H] + + # Accumulate with per-token routing weights for this expert + w_tok = weights.index_select(0, token_idx)[:, ge] # [Tk] + output.index_add_(0, token_idx, O * w_tok.unsqueeze(1)) # [Tk,H] * [Tk,1] + + return output.to(torch.bfloat16) + + +# ----------------------------- +# Helpers: FP8 block quantization (dequant scale semantics) +# ----------------------------- +def _fp8_block_quant_1d(x_bf16: torch.Tensor, block: int = 128): + """ + Quantize [T, H] activations into FP8 with per-(token, 128-col) block scales. + Returns: + x_fp8: [T, H] (float8_e4m3fn) + scales_TxNb: [T, H/128] (float32) -- dequant scales (float โ‰ˆ fp8 * scale) + """ + assert x_bf16.dim() == 2 + T, H = x_bf16.shape + assert H % block == 0 + nb = H // block + + finfo = torch.finfo(torch.float8_e4m3fn) + max_fp8 = finfo.max + + x_f32 = x_bf16.to(torch.float32) + x_fp8 = torch.empty((T, H), dtype=torch.float8_e4m3fn, device=x_bf16.device) + scales = torch.empty((T, nb), dtype=torch.float32, device=x_bf16.device) + + for j in range(nb): + sl = slice(j * block, (j + 1) * block) + blk = x_f32[:, sl] # [T, 128] + amax = torch.amax(torch.abs(blk), dim=1) # [T] + # dequant scale s = amax / max_fp8 (float โ‰ˆ fp8 * s) + s = torch.where(amax > 0, amax / max_fp8, torch.ones_like(amax)) + q = (blk / s.unsqueeze(1)).to(torch.float8_e4m3fn) # quantization + x_fp8[:, sl] = q + scales[:, j] = s + return x_fp8, scales # scales in [T, H/128] + + +def _fp8_block_quant_2d(w_bf16: torch.Tensor, block: int = 128): + """ + Quantize weights with 2D block scales over the last two dims. 
+ w_bf16: [*, R, C] (R and C are multiples of 128) + Returns: + w_fp8: [*, R, C] (float8_e4m3fn) + scales: [*, R/128, C/128] (float32) -- dequant scales + """ + assert w_bf16.dim() >= 2 + *prefix, R, C = w_bf16.shape + assert R % block == 0 and C % block == 0 + nb_r = R // block + nb_c = C // block + + finfo = torch.finfo(torch.float8_e4m3fn) + max_fp8 = finfo.max + + w_f32 = w_bf16.to(torch.float32).contiguous() + prefix_ndim = len(prefix) + + # Reshape weights into 128x128 blocks and move block dims to the tail: + # [..., nb_r, block, nb_c, block] -> [..., nb_r, nb_c, block, block] + reshaped = w_f32.reshape(*prefix, nb_r, block, nb_c, block) + permute_dims = tuple(range(prefix_ndim)) + ( + prefix_ndim, + prefix_ndim + 2, + prefix_ndim + 1, + prefix_ndim + 3, + ) + blocks = reshaped.permute(permute_dims).contiguous() + + # Compute per-block scales + amax = torch.amax(torch.abs(blocks), dim=(-1, -2)) + scales = torch.where( + amax > 0, + amax / max_fp8, + torch.ones_like(amax, dtype=torch.float32), + ) + + # Quantize blocks in parallel + q_blocks = (blocks / scales.unsqueeze(-1).unsqueeze(-1)).to(torch.float8_e4m3fn) + + # Restore original layout + inv_permute = [0] * (prefix_ndim + 4) + for i, d in enumerate(permute_dims): + inv_permute[d] = i + w_fp8 = q_blocks.permute(*inv_permute).reshape(*prefix, R, C) + + return w_fp8, scales + + +# ----------------------------- +# Random input generator for MoE DS-V3 +# ----------------------------- +def generate_random_inputs_moe( + seq_len: int, + *, + num_experts_global: int = 256, + num_local_experts: int = 32, + hidden_size: int = 7168, + intermediate_size: int = 2048, + use_bias: bool = True, + local_expert_offset: int = 0, + routed_scaling_factor: float = 2.5, + device: str = "cuda", +): + assert hidden_size % 128 == 0 and intermediate_size % 128 == 0 + T, H, I = seq_len, hidden_size, intermediate_size + E_global, E_local = num_experts_global, num_local_experts + + # Inputs for routing + routing_logits = torch.randn(T, E_global, dtype=torch.float32, device=device) + if use_bias: + routing_bias = torch.randn(E_global, dtype=torch.bfloat16, device=device) + else: + routing_bias = torch.zeros(E_global, dtype=torch.bfloat16, device=device) + + # Activations: start from bf16, then FP8 block-quant with dequant scales + a_bf16 = 2.0 * torch.randn(T, H, dtype=torch.bfloat16, device=device) + a_fp8, a_scales_TxNb = _fp8_block_quant_1d(a_bf16, block=128) # scales: [T, H/128] + hidden_states = a_fp8 + hidden_states_scale = a_scales_TxNb.transpose(0, 1).contiguous() # [H/128, T] + + # Weights per local expert + # W13: [E_local, 2I, H], W2: [E_local, H, I] + w13_bf16 = torch.randn(E_local, 2 * I, H, dtype=torch.bfloat16, device=device) + w2_bf16 = torch.randn(E_local, H, I, dtype=torch.bfloat16, device=device) + + w13_fp8, w13_scales = _fp8_block_quant_2d( + w13_bf16, block=128 + ) # scales: [E, (2I)/128, H/128] + w2_fp8, w2_scales = _fp8_block_quant_2d( + w2_bf16, block=128 + ) # scales: [E, H/128, I/128] + + return { + "routing_logits": routing_logits, + "routing_bias": routing_bias, + "hidden_states": hidden_states, + "hidden_states_scale": hidden_states_scale, + "gemm1_weights": w13_fp8, + "gemm1_weights_scale": w13_scales, + "gemm2_weights": w2_fp8, + "gemm2_weights_scale": w2_scales, + "local_expert_offset": int(local_expert_offset), + "local_num_experts": E_local, + "routed_scaling_factor": float(routed_scaling_factor), + } + + +# Max num tokens to tune for trtllm-gen fused moe +TUNE_MAX_NUM_TOKENS = 4096 + + +# ----------------------------- +# 
Test Entry +# ----------------------------- +@pytest.mark.parametrize( + "seq_len, local_expert_offset, use_bias", + [ + (1, 0, False), + (4, 0, True), + (8, 64, True), + (16, 32, True), + (64, 128, True), + (256, 64, True), + (1024, 32, True), + ], +) +@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 384, + "top_k": 8, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "compatible_intermediate_size": [1024, 2048], + "enable_autotune": True, + }, + id="kimi_k2", + ), + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "compatible_intermediate_size": [512, 1024, 2048], + "enable_autotune": True, + }, + id="DSv3", + ), + pytest.param( + { + "num_experts": 72, + "top_k": 6, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "compatible_intermediate_size": [384, 768], + "enable_autotune": False, + }, + id="DSLite", + ), + ], +) +@pytest.mark.parametrize("enable_pdl", [True, False]) +def test_correctness_dpsk_fp8_fused_moe( + seq_len, + local_expert_offset, + use_bias, + intermediate_size, + routing_config, + enable_pdl, + atol: float = 1e-1, + rtol: float = 2e-1, + percent: float = 0.85, +): + compatible_intermediate_size = routing_config["compatible_intermediate_size"] + if intermediate_size not in compatible_intermediate_size: + pytest.skip( + f"Intermediate size {intermediate_size} is not compatible with routing config {routing_config}" + ) + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + if trtllm_fp8_block_scale_moe is None: + pytest.skip("flashinfer fused_moe kernel not available") + + device = "cuda" + torch.manual_seed(42) + + # Constants (DeepSeek-V3) + E_GLOBAL = routing_config["num_experts"] # deepseek v3: 256 + E_LOCAL = 32 # todo(yingyi): default to tp8 for now, update later + H = 7168 + I = intermediate_size # deepseek v3: 2048 + TOP_K = routing_config["top_k"] # deepseek v3: 8 + N_GROUP = routing_config["n_groups"] # deepseek v3: 8 + TOPK_GROUP = routing_config["top_k_groups"] # deepseek v3: 4 + + if local_expert_offset + E_LOCAL > E_GLOBAL: + pytest.skip( + f"Local expert offset {local_expert_offset} + {E_LOCAL} is greater than number of experts {E_GLOBAL}" + ) + + # Generate random but consistent inputs + inputs = generate_random_inputs_moe( + seq_len, + num_experts_global=E_GLOBAL, + num_local_experts=E_LOCAL, + hidden_size=H, + intermediate_size=I, + use_bias=use_bias, + local_expert_offset=local_expert_offset, + routed_scaling_factor=routing_config["routed_scaling"], + device=device, + ) + + # Run reference (returns bf16) + ref_out = run( + routing_logits=inputs["routing_logits"], + routing_bias=inputs["routing_bias"], + hidden_states=inputs["hidden_states"], + hidden_states_scale=inputs["hidden_states_scale"], + gemm1_weights=inputs["gemm1_weights"], + gemm1_weights_scale=inputs["gemm1_weights_scale"], + gemm2_weights=inputs["gemm2_weights"], + gemm2_weights_scale=inputs["gemm2_weights_scale"], + local_expert_offset=inputs["local_expert_offset"], + routed_scaling_factor=inputs["routed_scaling_factor"], + hidden_size=H, + intermediate_size=I, + num_experts_global=E_GLOBAL, + num_local_experts=E_LOCAL, + top_k=TOP_K, + n_group=N_GROUP, + topk_group=TOPK_GROUP, + ) + + # Run FlashInfer fused kernel + with autotune(routing_config["enable_autotune"]): + fi_out = trtllm_fp8_block_scale_moe( + 
inputs["routing_logits"].to(torch.float32), + inputs["routing_bias"], # bf16 + inputs["hidden_states"], # fp8 + inputs["hidden_states_scale"], # [H/128, T] + inputs["gemm1_weights"], # fp8 + inputs["gemm1_weights_scale"].to(torch.float32), + inputs["gemm2_weights"], # fp8 + inputs["gemm2_weights_scale"].to(torch.float32), + E_GLOBAL, + TOP_K, + N_GROUP, + TOPK_GROUP, + I, + inputs["local_expert_offset"], + inputs["local_num_experts"], + inputs["routed_scaling_factor"], + routing_method_type=2, # DeepSeek-styled + use_shuffled_weight=False, + weight_layout=WeightLayout.MajorK.value, + enable_pdl=enable_pdl, + tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + ) + + # Compare + ref_f32 = ref_out.float() + fi_f32 = fi_out.float() + + abs_diff = (ref_f32 - fi_f32).abs() + rel_diff = abs_diff / (fi_f32.abs() + 1e-8) + + print("\nComparison stats:") + print(f"Max abs diff: {abs_diff.max().item():.6e}") + print(f"Mean abs diff: {abs_diff.mean().item():.6e}") + print(f"Max rel diff: {rel_diff.max().item():.6e}") + print(f"Mean rel diff: {rel_diff.mean().item():.6e}") + + # Cosine similarity and MSE + cos_sim = torch.nn.functional.cosine_similarity( + ref_f32.flatten(), fi_f32.flatten(), dim=0 + ).item() + mse = torch.mean((ref_f32 - fi_f32) ** 2).item() + print(f"Cosine similarity: {cos_sim:.6f}") + print(f"MSE: {mse:.6e}") + + # Strict allclose + allclose = torch.allclose(ref_f32, fi_f32, atol=atol, rtol=rtol) + print(f"\nAllclose(atol={atol}, rtol={rtol}): {allclose}") + + if not allclose: + # Show top-5 largest absolute errors + flat = abs_diff.flatten() + k = min(5, flat.numel()) + topv, topi = torch.topk(flat, k) + print("\nTop-5 absolute error locations:") + for rank in range(k): + idx = topi[rank].item() + t = idx // H + h = idx % H + print( + f" [t={t}, h={h}]: ref={ref_f32.flatten()[idx].item():.6e}, " + f"fi={fi_f32.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}" + ) + + left = (ref_f32 - fi_f32).abs() + right = atol + rtol * fi_f32.abs() + ok = left <= right + hit_ratio = ok.float().mean().item() + print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {percent * 100:.2f}%)") + + assert hit_ratio >= percent, ( + f"Hit ratio {hit_ratio * 100:.2f}% is less than required {percent * 100:.2f}%" + ) + + +if __name__ == "__main__": + test_correctness_dpsk_fp8_fused_moe( + seq_len=1, + local_expert_offset=0, + use_bias=False, + intermediate_size=2048, + routing_config={ + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "compatible_intermediate_size": [512, 1024, 2048], + "enable_autotune": True, + }, + enable_pdl=True, + ) From d42b71f589e95adb848e6060129df99a66f96941 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 16 Nov 2025 17:39:22 -0500 Subject: [PATCH 062/130] chore: update thor cuda arch (from 110f to 110a) (#2096) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Duplicate of #2091, created PR from flashinfer-ai to enable workflow. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. 
> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Corrected CUDA compute capability targeting from 11.0f to 11.0a for improved compatibility across build configurations. * **Documentation** * Updated installation and build documentation to reflect updated CUDA architecture configurations for both older and newer CUDA versions. --- .github/workflows/nightly-release.yml | 2 +- .github/workflows/release.yml | 2 +- README.md | 2 +- docs/installation.rst | 2 +- scripts/task_test_jit_cache_package_build_import.sh | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml index 2e7230cfa5..7c57d4bd7a 100644 --- a/.github/workflows/nightly-release.yml +++ b/.github/workflows/nightly-release.yml @@ -145,7 +145,7 @@ jobs: - name: Build wheel in container env: DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }} - FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0f 12.0f' }} + FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f' }} FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }} run: | # Extract CUDA major and minor versions diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0c95611c50..b11e72e1f7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -182,7 +182,7 @@ jobs: - name: Build wheel in container env: DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }} - FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0f 12.0f' }} + FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda < '13.0' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f' }} run: | # Extract CUDA major and minor versions CUDA_MAJOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f1) diff --git a/README.md b/README.md index 88b579b180..81b8583242 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ python -m pip install dist/*.whl `flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs): ```bash -export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0f 12.0f" +export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0a 12.0f" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl diff --git a/docs/installation.rst b/docs/installation.rst index 9087e87471..eb2f1acf67 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -92,7 +92,7 @@ You can follow the steps below to install FlashInfer from source code: .. 
code-block:: bash - export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0f 12.0f" + export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0a 12.0f" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl diff --git a/scripts/task_test_jit_cache_package_build_import.sh b/scripts/task_test_jit_cache_package_build_import.sh index d03937bc47..0627d7b82d 100755 --- a/scripts/task_test_jit_cache_package_build_import.sh +++ b/scripts/task_test_jit_cache_package_build_import.sh @@ -46,7 +46,7 @@ if cuda_ver is not None: if (major, minor) >= (13, 0): arches.append("10.0a") arches.append("10.3a") - arches.append("11.0f") + arches.append("11.0a") arches.append("12.0f") elif (major, minor) >= (12, 9): arches.append("10.0a") From 4aed50cfa663708fab3931b2d5fcab6adc5e1f55 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 16 Nov 2025 17:40:45 -0500 Subject: [PATCH 063/130] perf: enable pdl for cutlass fp4 gemm (#2095) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description The `enablePDL` flag is set to false, this PR turned them on. Set to true for both because sm_100 and sm_120 should have support of pdl. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Refactor** * Updated runtime configuration for FP4 GEMM operations to enhance execution performance on SM100 and SM120 GPU architectures. --- include/flashinfer/gemm/fp4_gemm_template_sm100.h | 2 +- include/flashinfer/gemm/fp4_gemm_template_sm120.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/gemm/fp4_gemm_template_sm100.h b/include/flashinfer/gemm/fp4_gemm_template_sm100.h index 3fa40ff9bd..5152e6e296 100644 --- a/include/flashinfer/gemm/fp4_gemm_template_sm100.h +++ b/include/flashinfer/gemm/fp4_gemm_template_sm100.h @@ -273,7 +273,7 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void std::string(cutlassGetStatusString(initStatus)); \ throw std::runtime_error("[FP4 gemm Runner] " + errMsg); \ } \ - auto runStatus = gemm.run(args, workspace, stream, nullptr, /* enablePDL */ false); \ + auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true); \ if (runStatus != cutlass::Status::kSuccess) { \ std::string errMsg = "Failed to run cutlass FP4 gemm on sm100. 
Error: " + \ std::string(cutlassGetStatusString(runStatus)); \ diff --git a/include/flashinfer/gemm/fp4_gemm_template_sm120.h b/include/flashinfer/gemm/fp4_gemm_template_sm120.h index 93b082bf5d..7333d81743 100644 --- a/include/flashinfer/gemm/fp4_gemm_template_sm120.h +++ b/include/flashinfer/gemm/fp4_gemm_template_sm120.h @@ -257,7 +257,7 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void std::string(cutlass::cutlassGetStatusString(initStatus)); \ throw std::runtime_error("[FP4 gemm Runner] " + errMsg); \ } \ - auto runStatus = gemm.run(args, workspace, stream, nullptr, /* enablePDL */ false); \ + auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true); \ if (runStatus != cutlass::Status::kSuccess) { \ std::string errMsg = "Failed to run cutlass FP4 gemm on sm120. Error: " + \ std::string(cutlass::cutlassGetStatusString(runStatus)); \ From 0a36050f2ed768bc6c65a587bf424020183fe1a8 Mon Sep 17 00:00:00 2001 From: FlashInfer Bot Date: Mon, 17 Nov 2025 23:00:48 -0800 Subject: [PATCH 064/130] chore: Update CODEOWNERS (#2098) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR updates the CODEOWNERS file based on git commit history analysis from the last 180 days. ## Changes - Updated `.github/CODEOWNERS` with current code ownership based on: - Commit frequency - File coverage - Commit recency ## How to Review 1. Review the changes to `.github/CODEOWNERS` 2. Verify that the assigned owners are appropriate for each module 3. Make manual adjustments if needed before merging ## Notes - This is an automated PR generated weekly - Minimum commits threshold: 1 - Analysis period: 180 days - Directory depth: 3 levels - Top N owners per module: 5 --- ๐Ÿค– This PR was automatically generated by the [update-codeowners workflow](.github/workflows/update-codeowners.yml) ## Summary by CodeRabbit ## Release Notes * **Chores** * Internal maintenance updates to code ownership mappings. --- **Note:** This release contains no user-facing changes. 
Co-authored-by: flashinfer-bot Co-authored-by: Claude --- .github/CODEOWNERS | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 24f6838702..fc3b20c491 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,8 +3,8 @@ # Analysis period: 180 days # Minimum commits threshold: 1 -benchmarks/ @bkryu @cyx-6 @yzh119 @jiahanc @nv-yunzheq -benchmarks/routines/ @bkryu @nv-yunzheq @cyx-6 @nvmbreughe @Anerudhan +benchmarks/ @bkryu @jiahanc @cyx-6 @yzh119 @nv-yunzheq +benchmarks/routines/ @bkryu @nv-yunzheq @jiahanc @cyx-6 @nvmbreughe ci/ @cyx-6 @yzh119 @nvmbreughe ci/scripts/ @cyx-6 ci/scripts/jenkins/ @cyx-6 @@ -17,7 +17,7 @@ csrc/nv_internal/include/ @wenscarl @nv-yunzheq csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @nv-yunzheq @yongwww @cyx-6 csrc/xqa/ @cyx-6 @yzh119 docs/ @yzh119 @cyx-6 @wenscarl @nv-yunzheq @aleozlx -flashinfer/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @jiahanc +flashinfer/ @yzh119 @cyx-6 @nvmbreughe @aleozlx @wenscarl flashinfer-cubin/ @yzh119 @cyx-6 flashinfer-cubin/flashinfer_cubin/ @yzh119 flashinfer-jit-cache/ @yzh119 @cyx-6 @@ -26,20 +26,20 @@ flashinfer/comm/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @djmmoss flashinfer/cudnn/ @Anerudhan @yzh119 @cyx-6 @Anerudhan flashinfer/cute_dsl/ @yzh119 @kaixih @Amir-19 @aleozlx flashinfer/dsv3_ops/ @nvmbreughe -flashinfer/fused_moe/ @djmmoss @yzh119 @cyx-6 @jiahanc @wenscarl +flashinfer/fused_moe/ @djmmoss @jiahanc @yzh119 @cyx-6 @aleozlx flashinfer/gemm/ @nvmbreughe -flashinfer/jit/ @yzh119 @cyx-6 @jiahanc @nvmbreughe @nv-yunzheq +flashinfer/jit/ @yzh119 @cyx-6 @aleozlx @jiahanc @nvmbreughe flashinfer/jit/attention/ @yzh119 @cyx-6 @Anerudhan @joker-eph flashinfer/jit/gemm/ @yzh119 @nv-yunzheq @jiahanc flashinfer/logits_processor/ @cyx-6 @yzh119 flashinfer/profiler/ @cyx-6 flashinfer/triton/ @nvmbreughe @cyx-6 flashinfer/tuning_configs/ @kaixih -include/ @yzh119 @jiahanc @nvmbreughe @bkryu @wenscarl -include/flashinfer/ @yzh119 @jiahanc @nvmbreughe @bkryu @wenscarl +include/ @yzh119 @jiahanc @nvmbreughe @IwakuraRein @bkryu +include/flashinfer/ @yzh119 @jiahanc @nvmbreughe @IwakuraRein @bkryu include/flashinfer/attention/ @yzh119 @kahyunnam @joker-eph include/flashinfer/comm/ @yongwww @nvmbreughe @djmmoss @yzh119 @cyx-6 -include/flashinfer/gemm/ @ttyio @yongwww @nvmbreughe @aleozlx -include/flashinfer/trtllm/ @jiahanc @joker-eph @aleozlx @yzh119 @cyx-6 +include/flashinfer/gemm/ @ttyio @yongwww @yzh119 @nvmbreughe @aleozlx +include/flashinfer/trtllm/ @jiahanc @joker-eph @aleozlx @yzh119 @wenscarl profiler/ @cyx-6 scripts/ @yzh119 @nvmbreughe @dierksen @yongwww @bkryu From 3b072473b7a59998f148fe7a636a450207045872 Mon Sep 17 00:00:00 2001 From: kahyun <69875166+kahyunnam@users.noreply.github.com> Date: Mon, 17 Nov 2025 23:15:13 -0800 Subject: [PATCH 065/130] feat: Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (fused RoPE + Q + KV cache, supports MLA/GQA/MHA) (#2037) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Add `flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache`, which runs a fused RoPE + Quantization (16 -> 8) + append KV Cache operation kernel. Note that this does not support optional quantization (there is no "RoPE + append KV Cache" fused operation available). Tested on NVIDIA H100 NVL + flashinfer/flashinfer-ci-cu130:latest for MLA/MHA/GQA problem sizes for decode and prefill cases. 
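A minimal GQA-style usage sketch, mirroring the call made by the benchmark script added in this PR. The shapes, unit quantization scales, and single-request paged metadata are illustrative only, and the half-cos/half-sin cache layout is assumed to match the reference RoPE helper used in the tests:

```python
import torch
import flashinfer

device, in_dtype = "cuda", torch.bfloat16
num_tokens, page_size = 8, 16
num_qo_heads, num_kv_heads, rope_dim, nope_dim = 32, 8, 64, 64
head_dim = rope_dim + nope_dim

# Q/K are split into RoPE and non-RoPE parts; V carries the full head_dim (GQA/MHA path).
q_rope = torch.randn(num_tokens, num_qo_heads, rope_dim, dtype=in_dtype, device=device)
q_nope = torch.randn(num_tokens, num_qo_heads, nope_dim, dtype=in_dtype, device=device)
k_rope = torch.randn(num_tokens, num_kv_heads, rope_dim, dtype=in_dtype, device=device)
k_nope = torch.randn(num_tokens, num_kv_heads, nope_dim, dtype=in_dtype, device=device)
v = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=in_dtype, device=device)

# Assumed cos/sin cache layout: (max_pos, rope_dim), first half cos, second half sin.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2, device=device).float() / rope_dim))
freqs = torch.outer(torch.arange(num_tokens, device=device).float(), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
pos_ids = torch.arange(num_tokens, dtype=torch.int32, device=device)

# Paged-KV metadata for a single request occupying ceil(num_tokens / page_size) pages.
num_pages = (num_tokens + page_size - 1) // page_size
kv_indptr = torch.tensor([0, num_pages], dtype=torch.int32, device=device)
kv_indices = torch.arange(num_pages, dtype=torch.int32, device=device)
kv_last_page_len = torch.tensor([(num_tokens - 1) % page_size + 1],
                                dtype=torch.int32, device=device)
append_indptr = torch.tensor([0, num_tokens], dtype=torch.int32, device=device)
seq_lens = flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_size)
batch_indices, positions = flashinfer.get_batch_indices_positions(
    append_indptr, seq_lens, num_tokens
)

# FP8 paged caches in NHD layout: (max_pages, page_size, num_kv_heads, head_dim).
k_cache = torch.zeros(num_pages, page_size, num_kv_heads, head_dim,
                      dtype=torch.float8_e4m3fn, device=device)
v_cache = torch.zeros_like(k_cache)

# Applies RoPE, quantizes to FP8, returns the quantized Q tensors, and
# appends the quantized K/V into the paged caches in place.
flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(
    q_rope=q_rope, k_rope=k_rope, q_nope=q_nope, k_nope=k_nope, v=v,
    cos_sin_cache=cos_sin_cache, pos_ids=pos_ids,
    paged_kv_cache=(k_cache, v_cache),
    kv_indices=kv_indices, kv_indptr=kv_indptr,
    batch_indices=batch_indices, positions=positions,
    page_size=page_size, kv_layout="NHD",
    quantize_dtype=torch.float8_e4m3fn,
    quant_scale_q=1.0, quant_scale_kv=1.0,
    is_neox=False, enable_pdl=False,
)
```

The MLA path follows the same pattern but passes `v=None`, 2D (or single-head) K tensors, and a `(ckv_cache, kpe_cache)` pair as `paged_kv_cache`, as exercised by the benchmark's `mla` configuration.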
## ๐Ÿ” Related Issues "[Model Optimization] Add RoPE, RoPE+Q, RoPE+Q+KVCacheUpdate fused kernels for MLA/GQA/MHA" item from Q4 roadmap: https://github.com/flashinfer-ai/flashinfer/issues/1770. This PR is part 2 to earlier PR for RoPE + Q: https://github.com/flashinfer-ai/flashinfer/pull/1924 FW Stakeholders: @nvpohanh @pavanimajety ## ๐Ÿงช Test results ``` $ pytest tests/attention/test_rope.py::test_rope_quantize_fp8_append_paged_kv_cache_decode -s ======================================================== test session starts =========================================================platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /workspace/flashinfer configfile: pytest.ini collected 384 items tests/attention/test_rope.py ................................................................................................................................................................................................................................................................................................................................................................................................ ======================================================== 384 passed in 35.22s ======================================================== ``` ``` $ pytest tests/attention/test_rope.py::test_generalized_rope_quantize_append_kv_cache -s ======================================================== test session starts ========================================================= platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /workspace/flashinfer configfile: pytest.ini collected 1248 items tests/attention/test_rope.py ......................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ....................................................................... 
================================================== 1248 passed in 63.07s (0:01:03) =================================================== ``` ``` $ python benchmarks/bench_rope_quantize_fp8_append_cache.py Detected GPU: NVIDIA GB200 Theoretical Peak Memory Bandwidth: 7928.06 GB/s ==================================================================================================== MLA: 128 Q heads, 1 K head, 64+512 dims (DeepSeek-style) ==================================================================================================== Tokens Time (ms) BW (GB/s) BW% (Peak) TFLOPs ---------------------------------------------------------------------- 1 0.00258 86.53 1.1 0.010 32 0.00381 1873.82 23.6 0.208 128 0.00763 3744.50 47.2 0.416 384 0.01848 4637.34 58.5 0.515 768 0.03694 4639.75 58.5 0.515 1024 0.04879 4683.57 59.1 0.520 2048 0.09590 4766.09 60.1 0.529 4096 0.19031 4803.27 60.6 0.533 8192 0.38523 4745.78 59.9 0.527 ==================================================================================================== GQA: 32 Q heads, 8 K heads, 64+64 dims (Llama-style) ==================================================================================================== Tokens Time (ms) BW (GB/s) BW% (Peak) TFLOPs ---------------------------------------------------------------------- 1 0.00294 6.36 0.1 0.003 32 0.00316 189.48 2.4 0.078 128 0.00317 755.23 9.5 0.310 384 0.00398 1803.09 22.7 0.741 768 0.00522 2750.51 34.7 1.130 1024 0.00617 3100.80 39.1 1.274 2048 0.00927 4130.83 52.1 1.697 4096 0.01631 4695.01 59.2 1.929 8192 0.03466 4418.01 55.7 1.815 ==================================================================================================== MHA: 32 Q heads, 32 K heads, 64+64 dims (Standard) ==================================================================================================== Tokens Time (ms) BW (GB/s) BW% (Peak) TFLOPs ---------------------------------------------------------------------- 1 0.00293 12.68 0.2 0.004 32 0.00313 379.98 4.8 0.126 128 0.00357 1331.80 16.8 0.441 384 0.00517 2756.73 34.8 0.912 768 0.00742 3840.41 48.4 1.271 1024 0.00887 4287.15 54.1 1.419 2048 0.01504 5055.18 63.8 1.673 4096 0.03343 4548.12 57.4 1.505 8192 0.06410 4744.76 59.8 1.571 ==================================================================================================== Configuration details: Page size: 32, Batch size: 4 Token range: 1 (single decode) โ†’ 8192 (large prefill) GPU: NVIDIA GB200 Theoretical Peak Memory Bandwidth: 7928.06 GB/s BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100 ==================================================================================================== ``` ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). 
## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Fused RoPE + FP8 quantize-and-append for paged KV caches (MLA, GQA/MHA) with layout, page-size, interleave and PDL options; returns quantized Q outputs and writes K/V into paged caches; public ops and high-level API added. * **Tests** * Deterministic, parameterized tests for append and decode/continuation across attention types, layouts, dtypes and quant settings with reference validation. * **Benchmarks** * New benchmark script for performance, bandwidth and Nsight profiling of the paged-KV quantize+append path. * **Chores** * Added cached GPU memory-bandwidth utility for benchmarks. --------- Co-authored-by: Zihao Ye --- .../bench_rope_quantize_fp8_append_cache.py | 342 +++++++ csrc/flashinfer_rope_binding.cu | 10 + csrc/rope.cu | 195 ++++ flashinfer/rope.py | 364 ++++++- flashinfer/utils.py | 41 + include/flashinfer/pos_enc.cuh | 468 ++++++++- include/flashinfer/utils.cuh | 46 + tests/attention/test_rope.py | 895 ++++++++++++++++++ 8 files changed, 2325 insertions(+), 36 deletions(-) create mode 100644 benchmarks/bench_rope_quantize_fp8_append_cache.py diff --git a/benchmarks/bench_rope_quantize_fp8_append_cache.py b/benchmarks/bench_rope_quantize_fp8_append_cache.py new file mode 100644 index 0000000000..3119b9fef8 --- /dev/null +++ b/benchmarks/bench_rope_quantize_fp8_append_cache.py @@ -0,0 +1,342 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import sys +import argparse +import flashinfer +import numpy as np +import torch +from flashinfer.testing.utils import bench_gpu_time_with_cudagraph +from flashinfer.utils import get_gpu_memory_bandwidth + +# Add the project root to Python path to import test helpers +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from tests.test_helpers.rope_reference import RotaryEmbedding + + +def benchmark_config( + config_name, + num_tokens, + batch_size=4, + page_size=16, + enable_pdl=False, + single_run=False, +): + """Benchmark a specific attention configuration with paged KV cache append.""" + input_dtype = torch.bfloat16 + device = "cuda" + quant_dtype = torch.float8_e4m3fn + + # Configuration-specific parameters + if config_name == "mla": + # MLA: DeepSeek-style multi-latent attention + num_qo_heads, num_kv_heads = 128, 1 + rope_dim, no_rope_dim = 64, 512 + elif config_name == "gqa": + # GQA: Grouped-query attention (e.g., Llama-style) + num_qo_heads, num_kv_heads = 32, 8 + rope_dim, no_rope_dim = 64, 64 + elif config_name == "mha": + # MHA: Standard multi-head attention + num_qo_heads, num_kv_heads = 32, 32 + rope_dim, no_rope_dim = 64, 64 + else: + raise ValueError(f"Unknown config: {config_name}") + + head_dim = rope_dim + no_rope_dim + + # Create input tensors + if config_name == "mla": + # MLA: 2D K tensors (shared) + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn(num_tokens, rope_dim, dtype=input_dtype, device=device) + k_nope = torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + v = None + else: + # GQA/MHA: 3D K/V tensors + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn( + num_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope = torch.randn( + num_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + v = torch.randn( + num_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # Create RoPE reference for cos/sin cache (ensure it covers this run) + max_seq_len = int(num_tokens) + rope_ref = RotaryEmbedding( + head_size=head_dim, + rotary_dim=rope_dim, + max_position_embeddings=max_seq_len, + base=10000, + is_neox_style=False, + dtype=input_dtype, + device=device, + ) + pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Build paged metadata (single request with all tokens) + kv_append_length = torch.tensor( + [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length, dim=0), + ] + ) + num_pages_per_req = torch.tensor( + [(num_tokens + page_size - 1) // page_size] + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(num_pages_per_req, dim=0), + ] + ) + kv_page_indices = torch.arange( + kv_page_indptr[-1].item(), dtype=torch.int32, device=device + ) + kv_last_page_len = torch.tensor( + [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + + # Get batch_indices and 
positions + seq_lens = flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size) + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, seq_lens, num_tokens + ) + + # Allocate caches + max_pages = kv_page_indptr[-1].item() + + if config_name == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, dtype=quant_dtype, device=device + ) + paged_kv_cache = (ckv_cache, kpe_cache) + else: + # GQA/MHA: use NHD layout + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + paged_kv_cache = (k_cache, v_cache) + + run_idx = 0 + + def execute(): + if single_run: + import torch.cuda.nvtx as nvtx + + nvtx.range_push("rope_append") + nonlocal run_idx + run_idx += 1 + + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope=q_rope, + k_rope=k_rope, + q_nope=q_nope, + k_nope=k_nope, + v=v, + cos_sin_cache=rope_ref.cos_sin_cache, + pos_ids=pos_ids, + paged_kv_cache=paged_kv_cache, + kv_indices=kv_page_indices, + kv_indptr=kv_page_indptr, + batch_indices=batch_indices, + positions=positions, + page_size=page_size, + kv_layout="NHD" if config_name != "mla" else "NHD", + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + if single_run: + # Ensure kernels complete inside the NVTX range for ncu filtering + torch.cuda.synchronize() + nvtx.range_pop() + + if single_run: + execute() + return None, None, None, None, None + measurements = bench_gpu_time_with_cudagraph(execute) + + # Calculate I/O bytes + # Inputs: q_rope, k_rope, q_nope, k_nope, v (if not MLA), cos_sin_cache, pos_ids + io_bytes = ( + q_rope.numel() * q_rope.element_size() + + k_rope.numel() * k_rope.element_size() + + q_nope.numel() * q_nope.element_size() + + k_nope.numel() * k_nope.element_size() + + rope_ref.cos_sin_cache.numel() * rope_ref.cos_sin_cache.element_size() + + pos_ids.numel() * pos_ids.element_size() + ) + + if v is not None: + io_bytes += v.numel() * v.element_size() + + # Outputs: q_rope_out, q_nope_out (FP8), cache writes (FP8) + io_bytes += ( + q_rope.numel() * torch.finfo(quant_dtype).bits // 8 + + q_nope.numel() * torch.finfo(quant_dtype).bits // 8 + ) + + if config_name == "mla": + # MLA writes to ckv_cache and kpe_cache + io_bytes += ( + num_tokens * no_rope_dim * torch.finfo(quant_dtype).bits // 8 + + num_tokens * rope_dim * torch.finfo(quant_dtype).bits // 8 + ) + else: + # GQA/MHA writes to k_cache and v_cache + io_bytes += ( + num_tokens * num_kv_heads * head_dim * torch.finfo(quant_dtype).bits // 8 + + num_tokens * num_kv_heads * head_dim * torch.finfo(quant_dtype).bits // 8 + ) + + # Calculate statistics + ms = np.median(measurements) + min_ms = np.percentile(measurements, 20) + max_ms = np.percentile(measurements, 80) + + # Calculate bandwidth in GB/s + bandwidth_gb_s = io_bytes / ms / 1e6 + + # Calculate TFLOPs (FP operations) + # RoPE: 6 FLOPs per dimension pair (2 muls + 1 sub for real, 2 muls + 1 add for imag) + # For Q: num_tokens * num_qo_heads * (rope_dim/2) pairs * 6 FLOPs + # For K: depends on architecture + q_flops = num_tokens * num_qo_heads * (rope_dim / 2) * 6 + + if config_name == "mla": + # MLA: K is 2D (no head dimension) + k_flops = num_tokens * (rope_dim / 2) * 6 + else: + # 
GQA/MHA: K is 3D (has head dimension) + k_flops = num_tokens * num_kv_heads * (rope_dim / 2) * 6 + + total_flops = q_flops + k_flops + tflops = ( + total_flops / ms / 1e9 + ) # TFLOPs (operations per ms = operations per second / 1e12) + + return ms, min_ms, max_ms, bandwidth_gb_s, tflops + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ncu-single", action="store_true", help="Run a single execute() for ncu" + ) + parser.add_argument( + "--config", type=str, default="", help="Config name: mla/gqa/mha" + ) + parser.add_argument("--num-tokens", type=int, default=0) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--enable-pdl", type=int, default=0) + args, unknown = parser.parse_known_args() + + if args.ncu_single: + # Minimal single-run for ncu profiling + cfg = args.config or "mla" + ntok = int(args.num_tokens) + pgsz = int(args.page_size) + en_pdl = bool(int(args.enable_pdl)) + # Force a single execution path + benchmark_config(cfg, ntok, page_size=pgsz, enable_pdl=en_pdl, single_run=True) + sys.exit(0) + + # Get GPU information (for display only) + device = torch.device("cuda:0") + gpu_name = torch.cuda.get_device_name(0) + gpu_peak_bandwidth = get_gpu_memory_bandwidth(device) + print(f"\nDetected GPU: {gpu_name}") + print(f"Theoretical Peak Memory Bandwidth: {gpu_peak_bandwidth:.2f} GB/s") + print() + + # Token counts to benchmark + token_counts = [1, 32, 128, 384, 768, 1024, 2048, 4096, 8192] + + # Helper function to print a table for a specific configuration + def print_config_table(config_name, config_desc): + page_size_to_benchmark = 32 + print(f"\n{'=' * 100}") + print(f" {config_name.upper()}: {config_desc}") + print(f"{'=' * 100}") + + print( + f"{'Tokens':<10} {'Time (ms)':<12} {'BW (GB/s)':<12} {'BW% (Peak)':<14} {'TFLOPs':<12}" + ) + print("-" * 70) + for num_tokens in token_counts: + ms, _, _, bw, tflops = benchmark_config( + config_name, num_tokens, page_size=page_size_to_benchmark + ) + bw_pct = (bw / gpu_peak_bandwidth) * 100 + print( + f"{num_tokens:<10} {ms:<12.5f} {bw:<12.2f} {bw_pct:<14.1f} {tflops:<12.3f}" + ) + + # Print tables for each configuration + print_config_table("mla", "128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)") + print_config_table("gqa", "32 Q heads, 8 K heads, 64+64 dims (Llama-style)") + print_config_table("mha", "32 Q heads, 32 K heads, 64+64 dims (Standard)") + + print("\n" + "=" * 100) + print("Configuration details:") + print(" Page size: 32, Batch size: 4") + print(" Token range: 1 (single decode) โ†’ 8192 (large prefill)") + print(f" GPU: {gpu_name}") + print(f" Theoretical Peak Memory Bandwidth: {gpu_peak_bandwidth:.2f} GB/s") + print(" BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100") + print("=" * 100) diff --git a/csrc/flashinfer_rope_binding.cu b/csrc/flashinfer_rope_binding.cu index 23124064d8..94809da735 100644 --- a/csrc/flashinfer_rope_binding.cu +++ b/csrc/flashinfer_rope_binding.cu @@ -45,9 +45,19 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope TensorView pos_ids, double quant_scale_q, double quant_scale_kv, bool interleave, bool enable_pdl); +void rope_quantize_append_paged_kv_cache( + TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in, + TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, TensorView k_cache, TensorView v_cache, TensorView ckv_cache, + TensorView kpe_cache, TensorView kv_indices, 
TensorView kv_indptr, TensorView batch_indices, + TensorView positions, int64_t kv_layout_code, int64_t page_size, double quant_scale_q, + double quant_scale_kv, bool interleave, bool enable_pdl); + TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids, apply_rope_pos_ids); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope_pos_ids, apply_llama31_rope_pos_ids); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_ids_cos_sin_cache); TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize_append_paged_kv_cache, + rope_quantize_append_paged_kv_cache); diff --git a/csrc/rope.cu b/csrc/rope.cu index 78cdcad405..40388d9412 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -420,3 +420,198 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope }); }); } + +/*! + * TVM FFI binding for fused RoPE + quantization + paged KV cache append kernel + * + * Validates tensor shapes, dimensions, and data types, then dispatches to the templated + * RopeQuantizeAppendPagedKVCache CUDA kernel implementation. + */ +void rope_quantize_append_paged_kv_cache( + TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in, + TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, + // Paged cache tensors + TensorView k_cache, TensorView v_cache, TensorView ckv_cache, TensorView kpe_cache, + TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, TensorView positions, + int64_t kv_layout_code, int64_t page_size, double quant_scale_q, double quant_scale_kv, + bool interleave, bool enable_pdl) { + // Validate inputs + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_nope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_out); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_out); + CHECK_INPUT(cos_sin_cache); + CHECK_INPUT(pos_ids); + CHECK_INPUT(kv_indices); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(batch_indices); + CHECK_INPUT(positions); + + // Extract dimensions + uint32_t rope_dim = q_rope_in.size(-1); + uint32_t no_rope_dim = q_nope_in.size(-1); + uint32_t nnz = q_rope_in.size(0); + uint32_t num_qo_heads = q_rope_in.size(1); + + // Validate dimensions + TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(q_rope_out.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(q_nope_out.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype()); + + // Validate input/output dtypes + TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16) + << "Input dtype must be float16 or bfloat16"; + TVM_FFI_ICHECK(q_rope_out.dtype() == dl_float8_e4m3fn || q_rope_out.dtype() == dl_float8_e5m2) + << "Output dtype must be float8_e4m3fn or float8_e5m2"; + + // Q tensors are always 3D + CHECK_DIM(3, q_rope_in); + CHECK_DIM(3, q_nope_in); + CHECK_DIM(3, q_rope_out); + CHECK_DIM(3, q_nope_out); + + // Detect architecture based on cache presence/layout (not K dimensionality) + QKVLayout kv_layout = QKVLayout(kv_layout_code); + bool has_mla_caches = (ckv_cache.data_ptr() != nullptr && 
kpe_cache.data_ptr() != nullptr); + bool has_gqa_caches = (k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr); + bool is_mla = has_mla_caches && !has_gqa_caches; + uint32_t num_kv_heads; + uint32_t batch_size = kv_indptr.size(0) - 1; + + // Require 3D K tensors in both paths; for MLA head dim must be 1 + CHECK_DIM(3, k_rope_in); + CHECK_DIM(3, k_nope_in); + if (is_mla) { + num_kv_heads = 1; + TVM_FFI_ICHECK_EQ(k_rope_in.size(1), 1) << "MLA expects K rope head dim == 1"; + TVM_FFI_ICHECK_EQ(k_nope_in.size(1), 1) << "MLA expects K nope head dim == 1"; + // V can be empty for MLA + TVM_FFI_ICHECK(v_in.data_ptr() == nullptr || v_in.size(0) == 0) + << "MLA should not have V input (or it should be empty)"; + // Validate MLA cache tensors are provided + TVM_FFI_ICHECK(ckv_cache.data_ptr() != nullptr && kpe_cache.data_ptr() != nullptr) + << "MLA requires ckv_cache and kpe_cache"; + CHECK_DIM(3, ckv_cache); // (max_pages, page_size, ckv_dim) + CHECK_DIM(3, kpe_cache); // (max_pages, page_size, kpe_dim) + TVM_FFI_ICHECK_EQ(ckv_cache.size(2), no_rope_dim); + TVM_FFI_ICHECK_EQ(kpe_cache.size(2), rope_dim); + } else { + // GQA/MHA validation + num_kv_heads = k_rope_in.size(1); + TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads); + // V is required for GQA/MHA + CHECK_DIM(3, v_in); + TVM_FFI_ICHECK_EQ(v_in.size(0), nnz); + TVM_FFI_ICHECK_EQ(v_in.size(1), num_kv_heads); + // Validate GQA/MHA cache tensors are provided + TVM_FFI_ICHECK(k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) + << "GQA/MHA requires k_cache and v_cache"; + // Cache must be 4D + CHECK_DIM(4, k_cache); + CHECK_DIM(4, v_cache); + } + + // Extract Q strides + const uint32_t q_rope_in_stride_n = q_rope_in.stride(0); + const uint32_t q_rope_in_stride_h = q_rope_in.stride(1); + const uint32_t q_nope_in_stride_n = q_nope_in.stride(0); + const uint32_t q_nope_in_stride_h = q_nope_in.stride(1); + const uint32_t q_rope_out_stride_n = q_rope_out.stride(0); + const uint32_t q_rope_out_stride_h = q_rope_out.stride(1); + const uint32_t q_nope_out_stride_n = q_nope_out.stride(0); + const uint32_t q_nope_out_stride_h = q_nope_out.stride(1); + + // Extract K strides + uint32_t k_rope_in_stride, k_nope_in_stride; + uint32_t k_rope_in_stride_h, k_nope_in_stride_h; + uint32_t v_in_stride = 0, v_in_stride_h = 0; + + k_rope_in_stride = k_rope_in.stride(0); + k_nope_in_stride = k_nope_in.stride(0); + k_rope_in_stride_h = k_rope_in.stride(1); + k_nope_in_stride_h = k_nope_in.stride(1); + if (!is_mla) { + v_in_stride = v_in.stride(0); + v_in_stride_h = v_in.stride(1); + } + + cudaSetDevice(q_rope_in.device().device_id); + const cudaStream_t stream = get_stream(q_rope_in.device()); + + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] { + cudaError_t status; + + if (is_mla) { + // MLA: Construct paged_kv_mla_t struct + auto ckv_strides = ckv_cache.strides(); + auto kpe_strides = kpe_cache.strides(); + + paged_kv_mla_t paged_kv_mla( + page_size, no_rope_dim, rope_dim, batch_size, + static_cast(ckv_cache.data_ptr()), ckv_strides.data(), + static_cast(kpe_cache.data_ptr()), kpe_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedMLACache( + static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), + 
static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv_mla, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, + q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, + q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, + k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave, + enable_pdl, stream); + + } else { + // GQA/MHA: Construct paged_kv_t struct + auto k_strides = k_cache.strides(); + auto v_strides = v_cache.strides(); + uint32_t head_dim = rope_dim + no_rope_dim; + + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout, + static_cast(k_cache.data_ptr()), + static_cast(v_cache.data_ptr()), k_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedKVCache( + static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), + static_cast(v_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, + no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, + q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, + q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, + k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv, + interleave, enable_pdl, stream); + } + + TVM_FFI_ICHECK(status == cudaSuccess) + << "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status); + return true; + }); + }); +} diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 7884c439be..dea6995bcf 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -226,6 +226,105 @@ def _fake_rope_quantize( pass +@register_custom_op( + "flashinfer::rope_quantize_append_paged_kv_cache", + mutates_args=( + "q_rope_out", + "q_nope_out", + "k_cache", + "v_cache", + "ckv_cache", + "kpe_cache", + ), +) +def _rope_quantize_fp8_append_paged_kv_cache( + q_rope_in: torch.Tensor, + k_rope_in: torch.Tensor, + q_nope_in: torch.Tensor, + k_nope_in: torch.Tensor, + v_in: torch.Tensor, + q_rope_out: torch.Tensor, + q_nope_out: torch.Tensor, + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + ckv_cache: torch.Tensor, + kpe_cache: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + kv_layout_code: int, + page_size: int, + quant_scale_q: float, + quant_scale_kv: float, + interleave: bool, + enable_pdl: bool, +) -> None: + r"""Custom operator that routes to the CUDA kernel implementation. + + Fuses RoPE application, FP8 quantization, and paged KV cache append into a single kernel. + + Converts is_neox parameter to interleave format and dispatches to the underlying + CUDA kernel via the JIT-compiled module. 
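+
+    The attention architecture is signaled by which cache pair carries allocated
+    storage: MLA callers pass real ``ckv_cache``/``kpe_cache`` together with size-0
+    ``k_cache``/``v_cache`` dummies, while GQA/MHA callers do the reverse; the C++
+    binding dispatches on whichever pair is non-empty.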
+ """ + get_rope_module().rope_quantize_append_paged_kv_cache( + q_rope_in, + k_rope_in, + q_nope_in, + k_nope_in, + v_in, + q_rope_out, + q_nope_out, + cos_sin_cache, + pos_ids, + k_cache, + v_cache, + ckv_cache, + kpe_cache, + kv_indices, + kv_indptr, + batch_indices, + positions, + kv_layout_code, + page_size, + quant_scale_q, + quant_scale_kv, + interleave, + enable_pdl, + ) + + +@register_fake_op("flashinfer::rope_quantize_append_paged_kv_cache") +def _fake_rope_quantize_fp8_append_paged_kv_cache( + q_rope_in: torch.Tensor, + k_rope_in: torch.Tensor, + q_nope_in: torch.Tensor, + k_nope_in: torch.Tensor, + v_in: torch.Tensor, + q_rope_out: torch.Tensor, + q_nope_out: torch.Tensor, + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + ckv_cache: torch.Tensor, + kpe_cache: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + kv_layout_code: int, + page_size: int, + quant_scale_q: float, + quant_scale_kv: float, + interleave: bool, + enable_pdl: bool, +) -> None: + pass + + @register_custom_op( "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope") ) @@ -1186,8 +1285,8 @@ def mla_rope_quantize_fp8( def rope_quantize_fp8( q_rope: torch.Tensor, k_rope: torch.Tensor, - q_nope: torch.Tensor, - k_nope: torch.Tensor, + q_nope: Optional[torch.Tensor], + k_nope: Optional[torch.Tensor], cos_sin_cache: torch.Tensor, pos_ids: torch.Tensor, is_neox: bool = True, @@ -1214,12 +1313,12 @@ def rope_quantize_fp8( k_rope : torch.Tensor Key tensor (rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, rope_dim)``. For MLA: ``(nnz, rope_dim)``. Must be float16 or bfloat16. - q_nope : torch.Tensor + q_nope : Optional[torch.Tensor] Query tensor (non-rotary dimensions), shape: ``(nnz, num_qo_heads, no_rope_dim)``. - Must be float16 or bfloat16. - k_nope : torch.Tensor + If ``None``, treated as zero-dim: a size-0 tensor will be created internally. + k_nope : Optional[torch.Tensor] Key tensor (non-rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, no_rope_dim)``. - For MLA: ``(nnz, no_rope_dim)``. Must be float16 or bfloat16. + For MLA: ``(nnz, no_rope_dim)``. If ``None``, treated as zero-dim and created internally. cos_sin_cache : torch.Tensor Precomputed cosine and sine values, shape: ``(max_seq_len, rope_dim)``. First half contains cosine values, second half contains sine values. Must be float32. 
@@ -1254,6 +1353,23 @@ def rope_quantize_fp8( if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") + # Allow None for nope tensors and normalize to size-0 tensors with correct shapes + nnz = q_rope.shape[0] + num_qo_heads = q_rope.shape[1] + is_mla = k_rope.ndim == 2 + num_kv_heads = 1 if is_mla else k_rope.shape[1] + if q_nope is None: + q_nope = torch.empty( + nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device + ) + if k_nope is None: + if is_mla: + k_nope = torch.empty(nnz, 0, dtype=k_rope.dtype, device=k_rope.device) + else: + k_nope = torch.empty( + nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device + ) + # Infer quantize_dtype from output tensors or default to float8_e4m3fn if quantize_dtype is None: for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out): @@ -1303,3 +1419,239 @@ def rope_quantize_fp8( ) return q_rope_out, k_rope_out, q_nope_out, k_nope_out + + +def rope_quantize_fp8_append_paged_kv_cache( + q_rope: torch.Tensor, + k_rope: torch.Tensor, + q_nope: Optional[torch.Tensor], + k_nope: Optional[torch.Tensor], + v: Optional[torch.Tensor], + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + paged_kv_cache: Tuple[torch.Tensor, torch.Tensor], + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + is_neox: bool = True, + quantize_dtype: Optional[torch.dtype] = None, + quant_scale_q: float = 1.0, + quant_scale_kv: float = 1.0, + page_size: int = 16, + kv_layout: str = "NHD", + q_rope_out: Optional[torch.Tensor] = None, + q_nope_out: Optional[torch.Tensor] = None, + enable_pdl: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Apply RoPE (Rotary Positional Embeddings), quantize to FP8, and append K/V to paged cache. + + This fused function applies RoPE to query/key (Q/K) rotary dimension tensors, quantizes all Q/K tensors + (and V for GQA/MHA) to FP8 format, and directly appends the quantized K/V to a paged KV cache. + It returns quantized Q tensors for use in attention computation. Supports MLA, GQA, and MHA + architectures with automatic detection based on input tensor shapes. + + Parameters + ---------- + q_rope : torch.Tensor + Query tensor (rotary dimensions), shape: ``(nnz, num_qo_heads, rope_dim)``. + Must be float16 or bfloat16. + k_rope : torch.Tensor + Key tensor (rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, rope_dim)``. + For MLA: ``(nnz, rope_dim)``. Must be float16 or bfloat16. + q_nope : torch.Tensor + Query tensor (non-rotary dimensions), shape: ``(nnz, num_qo_heads, no_rope_dim)``. + Must be float16 or bfloat16. + k_nope : torch.Tensor + Key tensor (non-rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, no_rope_dim)``. + For MLA: ``(nnz, no_rope_dim)``. Must be float16 or bfloat16. + v : Optional[torch.Tensor] + Value tensor for GQA/MHA: ``(nnz, num_kv_heads, head_dim)``. Must be float16 or bfloat16. + For MLA: pass ``None`` (MLA does not use separate V; K non-RoPE acts as compressed KV). + cos_sin_cache : torch.Tensor + Precomputed cosine and sine values, shape: ``(max_seq_len, rope_dim)``. + First half contains cosine values, second half contains sine values. Must be float32. + pos_ids : torch.Tensor + Position indices for each token, shape: ``(nnz,)``. 
+ paged_kv_cache : Tuple[torch.Tensor, torch.Tensor] + For MLA: ``(ckv_cache, kpe_cache)`` where: + - ckv_cache: ``(max_pages, page_size, no_rope_dim)`` in FP8 + - kpe_cache: ``(max_pages, page_size, rope_dim)`` in FP8 + For GQA/MHA: ``(k_cache, v_cache)`` where: + - k_cache: ``(max_pages, page_size, num_kv_heads, head_dim)`` or + ``(max_pages, num_kv_heads, page_size, head_dim)`` depending on layout, in FP8 + - v_cache: same shape as k_cache, in FP8 + kv_indices : torch.Tensor + Page indices mapping, shape: ``(total_pages,)``. Typically ``torch.arange(total_pages)``. + kv_indptr : torch.Tensor + Page indptr array for each request, shape: ``(batch_size + 1,)``. + ``kv_indptr[i]`` is the starting page index for request ``i``. + batch_indices : torch.Tensor + Batch index for each token, shape: ``(nnz,)``. Maps each token to its request. + positions : torch.Tensor + Position within each request's sequence for each token, shape: ``(nnz,)``. + is_neox : bool + RoPE layout style. If ``True`` (default), use non-interleaved layout (first/second half). + If ``False``, use interleaved layout (even/odd dimensions). + quantize_dtype : Optional[torch.dtype] + Target quantization dtype. If ``None``, inferred from output tensors or defaults to + ``torch.float8_e4m3fn``. Must be ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. + quant_scale_q : float + Quantization scaling factor for query tensors, default: ``1.0``. + quant_scale_kv : float + Quantization scaling factor for key/value tensors, default: ``1.0``. + page_size : int + Number of entries per page in the paged cache, default: ``16``. + kv_layout : str + Cache memory layout for GQA/MHA. Options: ``"NHD"`` (page, seq, head, dim) or + ``"HND"`` (page, head, seq, dim). Default: ``"NHD"``. Ignored for MLA. + q_rope_out : Optional[torch.Tensor] + Pre-allocated output tensor for quantized query (rotary). If ``None``, allocated automatically. + q_nope_out : Optional[torch.Tensor] + Pre-allocated output tensor for quantized query (non-rotary). If ``None``, allocated automatically. + enable_pdl : bool + Whether to enable PDL (Programmatic Dependent Launch). Default: ``False``. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Quantized query tensors: (q_rope_out, q_nope_out). + K/V are written directly to the paged cache and not returned. + + Notes + ----- + - Architecture detection: Automatically distinguishes MLA (2D K tensors) from GQA/MHA (3D K tensors). + - MLA writes K-RoPE to ``kpe_cache`` and K-noRoPE to ``ckv_cache``; V is not used. + - GQA/MHA writes full K (RoPE+noRoPE) to ``k_cache`` and V to ``v_cache``. + - The ``batch_indices`` and ``positions`` tensors are typically obtained from + ``flashinfer.get_batch_indices_positions()``. + - Cache tensors must already be allocated in the target FP8 dtype. 
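+
+    Examples
+    --------
+    A minimal single-request GQA sketch (sizes are illustrative and the cos/sin
+    table is a dummy for shape only; caches must be pre-allocated in the FP8 dtype):
+
+    >>> import torch
+    >>> import flashinfer
+    >>> nnz, num_qo_heads, num_kv_heads, rope_dim, page_size = 8, 32, 8, 128, 16
+    >>> device = "cuda:0"
+    >>> q_rope = torch.randn(nnz, num_qo_heads, rope_dim, dtype=torch.bfloat16, device=device)
+    >>> k_rope = torch.randn(nnz, num_kv_heads, rope_dim, dtype=torch.bfloat16, device=device)
+    >>> v = torch.randn(nnz, num_kv_heads, rope_dim, dtype=torch.bfloat16, device=device)
+    >>> cos_sin_cache = torch.randn(4096, rope_dim, dtype=torch.float32, device=device)
+    >>> pos_ids = torch.arange(nnz, dtype=torch.int32, device=device)
+    >>> num_pages = (nnz + page_size - 1) // page_size
+    >>> k_cache = torch.zeros(num_pages, page_size, num_kv_heads, rope_dim,
+    ...                       dtype=torch.float8_e4m3fn, device=device)
+    >>> v_cache = torch.zeros_like(k_cache)
+    >>> kv_indices = torch.arange(num_pages, dtype=torch.int32, device=device)
+    >>> kv_indptr = torch.tensor([0, num_pages], dtype=torch.int32, device=device)
+    >>> kv_append_indptr = torch.tensor([0, nnz], dtype=torch.int32, device=device)
+    >>> kv_last_page_len = torch.tensor([nnz - (num_pages - 1) * page_size],
+    ...                                 dtype=torch.int32, device=device)
+    >>> seq_lens = flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_size)
+    >>> batch_indices, positions = flashinfer.get_batch_indices_positions(
+    ...     kv_append_indptr, seq_lens, nnz)
+    >>> q_rope_fp8, q_nope_fp8 = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(
+    ...     q_rope, k_rope, None, None, v, cos_sin_cache, pos_ids,
+    ...     (k_cache, v_cache), kv_indices, kv_indptr, batch_indices, positions,
+    ...     is_neox=False, quantize_dtype=torch.float8_e4m3fn, page_size=page_size)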
+ """ + if cos_sin_cache.dtype != torch.float32: + raise ValueError("cos_sin_cache should be float32") + + # Detect architecture + is_mla = k_rope.ndim == 2 + + # Allow None for nope tensors and normalize to size-0 tensors with correct shapes + nnz = q_rope.shape[0] + num_qo_heads = q_rope.shape[1] + if q_nope is None: + q_nope = torch.empty( + nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device + ) + if k_nope is None: + if is_mla: + k_nope = torch.empty(nnz, 0, dtype=k_rope.dtype, device=k_rope.device) + else: + num_kv_heads = k_rope.shape[1] + k_nope = torch.empty( + nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device + ) + + # Infer quantize_dtype from output tensors or default + if quantize_dtype is None: + if q_rope_out is not None: + quantize_dtype = q_rope_out.dtype + elif q_nope_out is not None: + quantize_dtype = q_nope_out.dtype + else: + quantize_dtype = torch.float8_e4m3fn + + # Allocate Q output tensors if not provided + if q_rope_out is None: + q_rope_out = torch.empty_like(q_rope, dtype=quantize_dtype) + if q_nope_out is None: + q_nope_out = torch.empty_like(q_nope, dtype=quantize_dtype) + + # Handle MLA normalization and V (create empty dummy tensor, not used) + if is_mla: + # Normalize MLA K tensors to 3D (nnz, 1, dim) so C++ binding can always assume 3D + if k_rope.ndim == 2: + k_rope = k_rope.unsqueeze(1) + if k_nope.ndim == 2: + k_nope = k_nope.unsqueeze(1) + if v is None: + v = torch.empty(0, dtype=q_rope.dtype, device=q_rope.device) + else: + raise ValueError("MLA should not have V input (pass None)") + + # Unpack and validate cache tensors + if len(paged_kv_cache) != 2: + raise ValueError("paged_kv_cache must be a tuple of 2 tensors") + + cache_0, cache_1 = paged_kv_cache + + if is_mla: + # MLA: Expect (ckv_cache, kpe_cache) + ckv_cache = cache_0 + kpe_cache = cache_1 + if ckv_cache.dtype != quantize_dtype or kpe_cache.dtype != quantize_dtype: + raise ValueError( + f"MLA cache dtype mismatch: expected {quantize_dtype}, " + f"got ckv={ckv_cache.dtype}, kpe={kpe_cache.dtype}" + ) + if ckv_cache.ndim != 3 or kpe_cache.ndim != 3: + raise ValueError( + f"MLA cache must be 3D: (max_pages, page_size, dim), " + f"got ckv={ckv_cache.ndim}D, kpe={kpe_cache.ndim}D" + ) + # Create dummy tensors for GQA/MHA cache (not used) + k_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + v_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + else: + # GQA/MHA: Expect (k_cache, v_cache) + k_cache = cache_0 + v_cache = cache_1 + # Validate V input is provided for GQA/MHA + if v is None: + raise ValueError( + "GQA/MHA expects a V tensor, but got None. " + "Only MLA uses None for V (compressed KV representation)." 
+ ) + if k_cache.dtype != quantize_dtype or v_cache.dtype != quantize_dtype: + raise ValueError( + f"GQA/MHA cache dtype mismatch: expected {quantize_dtype}, " + f"got k={k_cache.dtype}, v={v_cache.dtype}" + ) + if k_cache.ndim != 4 or v_cache.ndim != 4: + raise ValueError( + f"GQA/MHA cache must be 4D, got k={k_cache.ndim}D, v={v_cache.ndim}D" + ) + # Create dummy tensors for MLA cache (not used) + ckv_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + kpe_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + + # Import TensorLayout enum + from .utils import TensorLayout + + kv_layout_code = TensorLayout[kv_layout].value + + # Call custom op + _rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + v, + q_rope_out, + q_nope_out, + cos_sin_cache, + pos_ids, + k_cache, + v_cache, + ckv_cache, + kpe_cache, + kv_indices, + kv_indptr, + batch_indices, + positions, + kv_layout_code, + page_size, + quant_scale_q, + quant_scale_kv, + not is_neox, # interleave + enable_pdl, + ) + + return q_rope_out, q_nope_out diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 771d616380..76689bab84 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -21,6 +21,7 @@ import torch import torch.version +import pynvml from torch.torch_version import TorchVersion from torch.torch_version import __version__ as torch_version import inspect @@ -255,6 +256,46 @@ def get_compute_capability(device: torch.device) -> Tuple[int, int]: return torch.cuda.get_device_capability(device.index) +@functools.cache +def get_gpu_memory_bandwidth(device: torch.device) -> float: + """ + Get GPU memory bandwidth in GB/s for the specified CUDA device. + + Args: + device: torch.device object, e.g., torch.device('cuda:0') + + Returns: + float: GPU memory bandwidth (GB/s) + + Raises: + ValueError: If device is not a CUDA device + """ + # Convert to torch.device object if string is passed + if isinstance(device, str): + device = torch.device(device) + + # Check if it's a CUDA device + if device.type != "cuda": + raise ValueError(f"Device must be a CUDA device, got {device}") + + # Get device index + device_index = device.index if device.index is not None else 0 + + # Use pynvml to get bandwidth + pynvml.nvmlInit() + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle) + mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM) + + # Calculate theoretical peak bandwidth (GB/s) + bandwidth = (mem_clock * bus_width * 2) / 8 / 1000 + + return bandwidth + finally: + pynvml.nvmlShutdown() + + def _check_cached_qkv_data_type( q: torch.Tensor, k: torch.Tensor, dtype_q: torch.dtype, dtype_kv: torch.dtype ) -> None: diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 7547a06090..7901b71e22 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -20,14 +20,40 @@ #include #include #include +#include #include "layout.cuh" #include "math.cuh" +#include "page.cuh" #include "utils.cuh" #include "vec_dtypes.cuh" namespace flashinfer { +struct RopeQuantizeAppendPagedKVCacheParams { + uint32_t nnz; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t rope_dim; + uint32_t no_rope_dim; + size_t q_rope_in_stride_n; + size_t q_rope_in_stride_h; + size_t q_nope_in_stride_n; + size_t q_nope_in_stride_h; + size_t q_rope_out_stride_n; + size_t q_rope_out_stride_h; + size_t q_nope_out_stride_n; + size_t q_nope_out_stride_h; + size_t 
k_rope_in_stride; + size_t k_rope_in_stride_h; + size_t k_nope_in_stride; + size_t k_nope_in_stride_h; + size_t v_in_stride; + size_t v_in_stride_h; + float quant_scale_q; + float quant_scale_kv; +}; + /*! * \brief An enumeration class that defines different modes for applying RoPE * (Rotary Positional Embeddings). @@ -384,7 +410,7 @@ __global__ void RopeQuantizeKernel( // 2. if not interleave // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)] - if ((tx * vec_size < rope_dim) and (by < k_rope_end)) { + if ((tx * vec_size < rope_dim) && (by < k_rope_end)) { int sin_offset = rope_dim / 2; int vec_idx; if constexpr (interleave) { @@ -717,34 +743,237 @@ __global__ void BatchQKApplyRotaryKernel( } } -#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ - if (interleave) { \ - const bool INTERLEAVE = true; \ - __VA_ARGS__ \ - } else { \ - const bool INTERLEAVE = false; \ - __VA_ARGS__ \ - } +/*! + * \brief Unified CUDA kernel to apply RoPE, quantize to FP8, and append to paged cache. + * + * Templated on CacheT to support both GQA/MHA (paged_kv_t) and MLA (paged_kv_mla_t). + * Cache-only behaviors are selected with constexpr on the CacheT. + */ +template +__global__ void RopeQuantizeAppendPagedKVCacheKernel( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, + QuantType* q_rope_out, QuantType* q_nope_out, CacheT paged_kv_like, + IdType* __restrict__ batch_indices, IdType* __restrict__ positions, + float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, + const RopeQuantizeAppendPagedKVCacheParams params) { +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + uint32_t by = blockIdx.y; + uint32_t bdy = blockDim.y; + + // Local aliases for params for readability + const uint32_t nnz = params.nnz; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t rope_dim = params.rope_dim; + const uint32_t no_rope_dim = params.no_rope_dim; + const size_t q_rope_in_stride_n = params.q_rope_in_stride_n; + const size_t q_rope_in_stride_h = params.q_rope_in_stride_h; + const size_t q_nope_in_stride_n = params.q_nope_in_stride_n; + const size_t q_nope_in_stride_h = params.q_nope_in_stride_h; + const size_t q_rope_out_stride_n = params.q_rope_out_stride_n; + const size_t q_rope_out_stride_h = params.q_rope_out_stride_h; + const size_t q_nope_out_stride_n = params.q_nope_out_stride_n; + const size_t q_nope_out_stride_h = params.q_nope_out_stride_h; + const size_t k_rope_in_stride = params.k_rope_in_stride; + const size_t k_rope_in_stride_h = params.k_rope_in_stride_h; + const size_t k_nope_in_stride = params.k_nope_in_stride; + const size_t k_nope_in_stride_h = params.k_nope_in_stride_h; + const size_t v_in_stride = params.v_in_stride; + const size_t v_in_stride_h = params.v_in_stride_h; + const float quant_scale_q = params.quant_scale_q; + const float quant_scale_kv = params.quant_scale_kv; + + // Calculate flexible boundaries for block allocation + uint32_t rope_chunk_size = rope_dim; + uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size; + uint32_t no_rope_chunks = (no_rope_dim + rope_chunk_size - 1) / rope_chunk_size; + + uint32_t q_rope_end = num_qo_heads * rope_chunks; + // For MLA, num_kv_heads is effectively 1 + uint32_t k_rope_end = 
q_rope_end + num_kv_heads * rope_chunks; + uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks; + + // Deduce MLA vs GQA/MHA from CacheT + constexpr bool IS_MLA = std::is_same>::value; + + vec_t cos, sin; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + + // Compute page location for this token + uint32_t page_iter, entry_idx; + paged_kv_like.page_size.divmod( + paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx], + page_iter, entry_idx); + + const int half_rope_dim = rope_dim / 2; + // Load cos/sin for RoPE processing blocks only + if ((tx * vec_size < rope_dim) && (by < k_rope_end)) { + int sin_offset = rope_dim / 2; + int vec_idx; + if constexpr (interleave) { + vec_idx = (tx * vec_size) / 2; // Force integer division + } else { + vec_idx = (tx * vec_size) % half_rope_dim; + } + cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx); + sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx)); + } + + if (by < q_rope_end) { + // ============ Q RoPE processing ============ + uint32_t q_head_idx = by / rope_chunks; + uint32_t rope_chunk_idx = by % rope_chunks; + uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; + + DType* q_rope_in_ptr = + q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n, + q_rope_in_stride_h); + QuantType* q_rope_out_ptr = + q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n, + q_rope_out_stride_h); + + vec_t q_rope_vec; + if constexpr (interleave) { + q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( + q_rope_in_ptr, cos, sin, rope_dim); + } else { + q_rope_vec = vec_apply_llama_rope_cos_sin(q_rope_in_ptr, cos, sin, rope_dim); + } +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; + } + q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); + + } else if (by < k_rope_end) { + // ============ K RoPE processing & Cache Append ============ + uint32_t k_head_idx = (by - q_rope_end) / rope_chunks; + uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks; + uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; + + DType* k_rope_in_ptr; + if constexpr (IS_MLA) { + // MLA: 2D K + k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset; + } else { + // GQA/MHA: 3D K + k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, + k_rope_in_stride, k_rope_in_stride_h); + } + + vec_t k_rope_vec; + if constexpr (interleave) { + k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( + k_rope_in_ptr, cos, sin, rope_dim); + } else { + k_rope_vec = vec_apply_llama_rope_cos_sin(k_rope_in_ptr, cos, sin, rope_dim); + } +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; + } + + if constexpr (IS_MLA) { + QuantType* kpe_ptr = + paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); + k_rope_vec.cast_store(kpe_ptr); + } else { + QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size); + k_rope_vec.cast_store(k_ptr); + } + + } else if (by < k_nope_end) { + // ============ K Non-RoPE processing & Cache Append ============ + uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks; + uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks; + uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; + + DType* k_nope_in_ptr; + if constexpr (IS_MLA) { + k_nope_in_ptr = k_nope_in + idx * 
k_nope_in_stride + elem_offset; + } else { + k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, + k_nope_in_stride, k_nope_in_stride_h); + } + + vec_t k_nope_vec; + k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv; + } -#define DISPATCH_ROPE_DIM(rope_dim, vec_size, ...) \ - if (rope_dim == 16) { \ - constexpr uint32_t bdx = 16 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 32) { \ - constexpr uint32_t bdx = 32 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 64) { \ - constexpr uint32_t bdx = 64 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 128) { \ - constexpr uint32_t bdx = 128 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 256) { \ - constexpr uint32_t bdx = 256 / vec_size; \ - __VA_ARGS__ \ - } else { \ - FLASHINFER_ERROR("Unsupported rope_dim. Supported values: 16, 32, 64, 128, 256"); \ + if constexpr (IS_MLA) { + QuantType* ckv_ptr = + paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); + k_nope_vec.cast_store(ckv_ptr); + } else { + QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, + rope_dim + elem_offset + tx * vec_size); + k_nope_vec.cast_store(k_ptr); + } + + } else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) { + // ============ V processing & Cache Append (GQA/MHA only) ============ + if constexpr (!IS_MLA) { + uint32_t kv_head_idx = by - k_nope_end; + DType* v_in_ptr = + v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h); + // Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size + uint32_t head_dim_total = rope_dim + no_rope_dim; + uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size; +#pragma unroll 1 + for (uint32_t j = 0; j < v_chunks; ++j) { + uint32_t v_elem_offset = j * rope_chunk_size; + if (v_elem_offset + tx * vec_size < head_dim_total) { + vec_t v_vec; + v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + v_vec[i] = v_vec[i] * quant_scale_kv; + } + QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, + v_elem_offset + tx * vec_size); + v_vec.cast_store(v_ptr); + } + } + } + + } else { + // ============ Q Non-RoPE processing ============ + // MLA has no V section, so Q-nope starts immediately after K-nope. + // GQA/MHA has a V section of length num_kv_heads blocks. + uint32_t q_nope_start = k_nope_end + (IS_MLA ? 
0u : num_kv_heads); + uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks; + uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks; + uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; + + DType* q_nope_in_ptr = + q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n, + q_nope_in_stride_h); + QuantType* q_nope_out_ptr = + q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n, + q_nope_out_stride_h); + + vec_t q_nope_vec; + q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + q_nope_vec[i] = q_nope_vec[i] * quant_scale_q; + } + q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size); + } } +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} template cudaError_t RopeQuantize( @@ -763,11 +992,11 @@ cudaError_t RopeQuantize( FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - constexpr uint32_t vec_size = 32 / sizeof(DType); - // Use nested macros for runtime->compile-time dispatch for required constexpr values - DISPATCH_ROPE_DIM(rope_dim, vec_size, { + DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + constexpr uint32_t bdx = ROPE_DIM / vec_size; uint32_t num_threads = 128U; uint32_t bdy = num_threads / bdx; uint32_t nblks_x = (nnz + bdy - 1) / bdy; @@ -838,6 +1067,185 @@ cudaError_t RopeQuantize( return cudaSuccess; } +/*! + * \brief Host function to apply RoPE, quantize to FP8, and append K/V to paged cache (GQA/MHA) + */ +template +cudaError_t RopeQuantizeAppendPagedKVCache( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, + QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t paged_kv, + IdType* batch_indices, IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim, + size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, + size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, + size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, + size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h, + size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv, + bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { + DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + constexpr uint32_t bdx = ROPE_DIM / vec_size; + uint32_t num_threads = 128U; + uint32_t bdy = num_threads / bdx; + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunks = 1; + uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; + + // GQA/MHA: Q rope + K rope + K nope + V + Q nope + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_kv_heads + + num_qo_heads * no_rope_chunks; + + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 
1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + auto kernel = RopeQuantizeAppendPagedKVCacheKernel>; + RopeQuantizeAppendPagedKVCacheParams params; + params.nnz = nnz; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = num_kv_heads; + params.rope_dim = rope_dim; + params.no_rope_dim = no_rope_dim; + params.q_rope_in_stride_n = q_rope_in_stride_n; + params.q_rope_in_stride_h = q_rope_in_stride_h; + params.q_nope_in_stride_n = q_nope_in_stride_n; + params.q_nope_in_stride_h = q_nope_in_stride_h; + params.q_rope_out_stride_n = q_rope_out_stride_n; + params.q_rope_out_stride_h = q_rope_out_stride_h; + params.q_nope_out_stride_n = q_nope_out_stride_n; + params.q_nope_out_stride_h = q_nope_out_stride_h; + params.k_rope_in_stride = k_rope_in_stride; + params.k_rope_in_stride_h = k_rope_in_stride_h; + params.k_nope_in_stride = k_nope_in_stride; + params.k_nope_in_stride_h = k_nope_in_stride_h; + params.v_in_stride = v_in_stride; + params.v_in_stride_h = v_in_stride_h; + params.quant_scale_q = quant_scale_q; + params.quant_scale_kv = quant_scale_kv; + + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, + // inputs + q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in, + // q outputs + q_rope_out, q_nope_out, + // cache + indices + paged_kv, batch_indices, positions, + // rope tables + cos_sin_cache, pos_ids, + // params + params)); + }); + }); + + return cudaSuccess; +} + +/*! + * \brief Host function to apply RoPE, quantize to FP8, and append to MLA paged cache + */ +template +cudaError_t RopeQuantizeAppendPagedMLACache( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out, + QuantType* q_nope_out, paged_kv_mla_t paged_kv_mla, IdType* batch_indices, + IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, + uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, + size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, + size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, + size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv, + bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { + DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + constexpr uint32_t bdx = ROPE_DIM / vec_size; + uint32_t num_threads = 128U; + uint32_t bdy = num_threads / bdx; + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunks = 1; + uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; + + // MLA: Q rope + K rope + K nope + Q nope (no V) + constexpr uint32_t num_kv_heads = 1; + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_qo_heads * no_rope_chunks; + + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 
1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + auto kernel = + RopeQuantizeAppendPagedKVCacheKernel>; + // For MLA: pass v_in as nullptr, num_kv_heads=1, duplicate 2D K strides for head strides, and + // 0 V strides + DType* v_in_nullptr = nullptr; + uint32_t num_kv_heads_1 = 1; + size_t k_rope_in_stride_h_dup = k_rope_in_stride; + size_t k_nope_in_stride_h_dup = k_nope_in_stride; + size_t v_in_stride_zero = 0, v_in_stride_h_zero = 0; + RopeQuantizeAppendPagedKVCacheParams params; + params.nnz = nnz; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = 1u; + params.rope_dim = rope_dim; + params.no_rope_dim = no_rope_dim; + params.q_rope_in_stride_n = q_rope_in_stride_n; + params.q_rope_in_stride_h = q_rope_in_stride_h; + params.q_nope_in_stride_n = q_nope_in_stride_n; + params.q_nope_in_stride_h = q_nope_in_stride_h; + params.q_rope_out_stride_n = q_rope_out_stride_n; + params.q_rope_out_stride_h = q_rope_out_stride_h; + params.q_nope_out_stride_n = q_nope_out_stride_n; + params.q_nope_out_stride_h = q_nope_out_stride_h; + params.k_rope_in_stride = k_rope_in_stride; + params.k_rope_in_stride_h = k_rope_in_stride_h_dup; + params.k_nope_in_stride = k_nope_in_stride; + params.k_nope_in_stride_h = k_nope_in_stride_h_dup; + params.v_in_stride = 0; + params.v_in_stride_h = 0; + params.quant_scale_q = quant_scale_q; + params.quant_scale_kv = quant_scale_kv; + + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, + // inputs + q_rope_in, k_rope_in, q_nope_in, k_nope_in, + v_in_nullptr, + // q outputs + q_rope_out, q_nope_out, + // cache + indices + paged_kv_mla, batch_indices, positions, + // rope tables + cos_sin_cache, pos_ids, + // params + params)); + }); + }); + + return cudaSuccess; +} + template cudaError_t BatchQKApplyRotaryPosIdsCosSinCache( DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_sin_cache, IdType* pos_ids, diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 5b26d7beaf..0471bd1081 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -201,6 +201,52 @@ } \ } +// convert interleave to compile-time constant +#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ + if (interleave) { \ + constexpr bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, ...) \ + switch (rope_dim) { \ + case 16: { \ + constexpr uint32_t ROPE_DIM = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 32: { \ + constexpr uint32_t ROPE_DIM = 32; \ + __VA_ARGS__ \ + break; \ + } \ + case 64: { \ + constexpr uint32_t ROPE_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: { \ + constexpr uint32_t ROPE_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 256: { \ + constexpr uint32_t ROPE_DIM = 256; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported ROPE_DIM: " << rope_dim; \ + err_msg << ". Supported values: 16, 32, 64, 128, 256"; \ + err_msg << " in DISPATCH_ROPE_DIM"; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + #define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) 
\ switch (pos_encoding_mode) { \ case PosEncodingMode::kNone: { \ diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index da59223a4f..8e694088e5 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -394,6 +394,10 @@ def test_generalized_rope_quantize( ): """Test generalized rope + quantization for MLA, GQA, and MHA architectures.""" device = "cuda:0" + # Fixed seed for reproducibility across tests + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) total_dim = rope_dim + no_rope_dim # Create input tensors based on attention type @@ -481,6 +485,893 @@ def test_generalized_rope_quantize( ) +@pytest.mark.parametrize( + "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim", + [ + # MLA: Multiple Q heads, single shared K/V head + ("mla", 128, 1, 64, 512), + ("mla", 64, 1, 128, 256), + ("mla", 128, 1, 64, 128), # Explicit DeepSeek R1 MLA config case + ("mla", 32, 1, 32, 96), + # GQA: Multiple Q heads, fewer K/V heads (grouped) + ("gqa", 32, 8, 64, 64), + ("gqa", 64, 16, 128, 128), + ("gqa", 24, 6, 32, 96), + ("gqa", 32, 8, 128, 0), # Llama3 8B standard config + ("gqa", 64, 8, 128, 0), # Llama3 70B standard config + ("gqa", 64, 8, 64, 0), # (plausible) GPT-OSS config + # MHA: Equal Q and K/V heads + ("mha", 32, 32, 64, 64), + ("mha", 16, 16, 128, 128), + ("mha", 8, 8, 32, 96), + ], +) +@pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) +@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("enable_pdl", [True, False]) +@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +@pytest.mark.parametrize("page_size", [16, 32]) +def test_generalized_rope_quantize_append_kv_cache( + attention_type, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + num_tokens, + input_dtype, + quant_dtype, + enable_pdl, + kv_layout, + page_size, +): + device = "cuda:0" + # Fixed seed for reproducibility + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + head_dim = rope_dim + no_rope_dim + batch_size = 4 + + # Build inputs following the same pattern used elsewhere + if attention_type == "mla": + # Q: (N, Hq, *), K: 2D (N, *) + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + ) + k_rope = torch.randn(num_tokens, rope_dim, dtype=input_dtype, device=device) + k_nope = ( + None + if no_rope_dim == 0 + else torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + ) + v = None + else: + # GQA/MHA: K/V are 3D + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + ) + k_rope = torch.randn( + num_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + ) + v = torch.randn( + num_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # Cos/sin and positions + max_seq_len = 4096 + rope_ref = FlashInferRotaryEmbedding( + head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device + ) + 
pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Build paged metadata + kv_append_length = torch.tensor( + [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length, dim=0), + ] + ) + num_pages_per_req = torch.tensor( + [(num_tokens + page_size - 1) // page_size] + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(num_pages_per_req, dim=0), + ] + ) + kv_page_indices = torch.arange( + kv_page_indptr[-1].item(), dtype=torch.int32, device=device + ) + kv_last_page_len = torch.tensor( + [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + # Allocate caches sized by required pages + max_pages = kv_page_indptr[-1].item() + + # Get batch_indices and positions + seq_lens = flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size) + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, seq_lens, num_tokens + ) + + # Fused call + cache allocation + if attention_type == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, dtype=quant_dtype, device=device + ) + q_rope_out_fused, q_nope_out_fused = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + None, + rope_ref.cos_sin_cache, + pos_ids, + (ckv_cache, kpe_cache), + kv_page_indices, + kv_page_indptr, + batch_indices, + positions, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + else: + # Allocate cache based on layout + if kv_layout == "NHD": + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + else: # HND + k_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + q_rope_out_fused, q_nope_out_fused = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + v, + rope_ref.cos_sin_cache, + pos_ids, + (k_cache, v_cache), + kv_page_indices, + kv_page_indptr, + batch_indices, + positions, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + # Compute reference output (handle None for no_rope_dim == 0) + q_in = q_rope if q_nope is None else torch.cat([q_rope, q_nope], dim=-1) + k_in = k_rope if k_nope is None else torch.cat([k_rope, k_nope], dim=-1) + q_out_f16_ref, k_out_f16_ref = rope_ref.forward_native(pos_ids, q_in, k_in) + q_out_f8_ref, k_out_f8_ref = map( + lambda x: x.to(quant_dtype), + (q_out_f16_ref, k_out_f16_ref), + ) + + # Fused vs Pytorch reference Q checks + torch.testing.assert_close( + q_out_f8_ref[..., :rope_dim].float(), + q_rope_out_fused.float(), + rtol=2e-1, + atol=1e-2, + ) + torch.testing.assert_close( + q_out_f8_ref[..., 
rope_dim:].float(), + q_nope_out_fused.float(), + rtol=2e-1, + atol=1e-2, + ) + + # expect 1-ULP differences between FP8 device rounding and PyTorch .to(fp8) + if quant_dtype == torch.float8_e4m3fn: + rtol_val, atol_val = 0.25, 0.5 + else: # quant_dtype == torch.float8_e5m2: + rtol_val, atol_val = 0.25, 1.0 + + # if MLA: check ckv_cache, kpe_cache + if attention_type == "mla": + # Split K reference + k_rope_ref = k_out_f8_ref[..., :rope_dim] + k_nope_ref = k_out_f8_ref[..., rope_dim:] + + ckv_ref = torch.zeros_like(ckv_cache) + kpe_ref = torch.zeros_like(kpe_cache) + + for i in range(num_tokens): + b = batch_indices[i].item() + pos = positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + ckv_ref[page_idx, entry_idx, :] = k_nope_ref[i] + kpe_ref[page_idx, entry_idx, :] = k_rope_ref[i] + + torch.testing.assert_close( + ckv_cache.float(), ckv_ref.float(), rtol=rtol_val, atol=atol_val + ) + torch.testing.assert_close( + kpe_cache.float(), kpe_ref.float(), rtol=rtol_val, atol=atol_val + ) + + # if GQA/MHA: check k_cache, v_cache + if attention_type == "gqa" or attention_type == "mha": + # K reference + k_ref = torch.zeros_like(k_cache) + for i in range(num_tokens): + b = batch_indices[i].item() + pos = positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + if kv_layout == "NHD": + k_ref[page_idx, entry_idx, :, :] = k_out_f8_ref[i] # [Hkv, head_dim] + else: # HND + k_ref[page_idx, :, entry_idx, :] = k_out_f8_ref[i] # [Hkv, head_dim] + + torch.testing.assert_close( + k_cache.float(), k_ref.float(), rtol=rtol_val, atol=atol_val + ) + + # V reference (no RoPE on V; same quant scale as KV) + quant_scale_kv = 1.0 # match fused call + v_ref_tokens = (v * quant_scale_kv).to(quant_dtype) + v_ref = torch.zeros_like(v_cache) + for i in range(num_tokens): + b = batch_indices[i].item() + pos = positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + if kv_layout == "NHD": + v_ref[page_idx, entry_idx, :, :] = v_ref_tokens[i] + else: # HND + v_ref[page_idx, :, entry_idx, :] = v_ref_tokens[i] + + torch.testing.assert_close( + v_cache.float(), v_ref.float(), rtol=rtol_val, atol=atol_val + ) + + +@pytest.mark.parametrize( + "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim", + [ + # MLA: Multiple Q heads, single shared K/V head + ("mla", 128, 1, 64, 512), + ("mla", 32, 1, 32, 96), + # GQA: Multiple Q heads, fewer K/V heads (grouped) + ("gqa", 32, 8, 64, 64), + ("gqa", 32, 8, 128, 0), # Llama3 8B standard config + # MHA: Equal Q and K/V heads + ("mha", 32, 32, 64, 64), + ("mha", 16, 16, 128, 128), + ], +) +@pytest.mark.parametrize("num_existing_tokens", [10, 50]) +@pytest.mark.parametrize("num_new_tokens", [1, 8]) +@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("enable_pdl", [True, False]) +@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +@pytest.mark.parametrize("page_size", [16, 32]) +def test_rope_quantize_fp8_append_paged_kv_cache_decode( + attention_type, + num_qo_heads, + num_kv_heads, + 
rope_dim, + no_rope_dim, + num_existing_tokens, + num_new_tokens, + input_dtype, + quant_dtype, + enable_pdl, + kv_layout, + page_size, +): + """Test append to non-empty cache (decode/continuation scenario).""" + device = "cuda:0" + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + + head_dim = rope_dim + no_rope_dim + batch_size = 2 + + # Step 1: Pre-populate cache with existing tokens + if attention_type == "mla": + q_rope_existing = torch.randn( + num_existing_tokens, + num_qo_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + q_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + k_rope_existing = torch.randn( + num_existing_tokens, rope_dim, dtype=input_dtype, device=device + ) + k_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, no_rope_dim, dtype=input_dtype, device=device + ) + ) + v_existing = None + else: + q_rope_existing = torch.randn( + num_existing_tokens, + num_qo_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + q_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + k_rope_existing = torch.randn( + num_existing_tokens, + num_kv_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + k_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_kv_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + v_existing = torch.randn( + num_existing_tokens, + num_kv_heads, + head_dim, + dtype=input_dtype, + device=device, + ) + + # Create RoPE reference + max_seq_len = 4096 + rope_ref = FlashInferRotaryEmbedding( + head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device + ) + pos_ids_existing = torch.arange( + num_existing_tokens, device=device, dtype=torch.int32 + ) + + # Build metadata for existing tokens (single request for simplicity) + kv_append_length_existing = torch.tensor( + [num_existing_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr_existing = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length_existing, dim=0), + ] + ) + num_pages_existing = (num_existing_tokens + page_size - 1) // page_size + kv_page_indptr_existing = torch.tensor( + [0, num_pages_existing] + [num_pages_existing] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indices_existing = torch.arange( + num_pages_existing, dtype=torch.int32, device=device + ) + kv_last_page_len_existing = torch.tensor( + [ + num_existing_tokens % page_size + if num_existing_tokens % page_size != 0 + else page_size + ] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + seq_lens_existing = flashinfer.get_seq_lens( + kv_page_indptr_existing, kv_last_page_len_existing, page_size + ) + batch_indices_existing, positions_existing = flashinfer.get_batch_indices_positions( + kv_append_indptr_existing, seq_lens_existing, num_existing_tokens + ) + + # Allocate cache sized for existing + new tokens + total_tokens = num_existing_tokens + num_new_tokens + max_pages = (total_tokens + page_size - 1) // page_size + + if attention_type == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, 
dtype=quant_dtype, device=device + ) + # Pre-populate with existing tokens + _, _ = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_existing, + k_rope_existing, + q_nope_existing, + k_nope_existing, + None, + rope_ref.cos_sin_cache, + pos_ids_existing, + (ckv_cache, kpe_cache), + kv_page_indices_existing, + kv_page_indptr_existing, + batch_indices_existing, + positions_existing, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + else: + if kv_layout == "NHD": + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + else: # HND + k_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + # Pre-populate with existing tokens + _, _ = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_existing, + k_rope_existing, + q_nope_existing, + k_nope_existing, + v_existing, + rope_ref.cos_sin_cache, + pos_ids_existing, + (k_cache, v_cache), + kv_page_indices_existing, + kv_page_indptr_existing, + batch_indices_existing, + positions_existing, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + + # Step 2: Append new tokens to the pre-populated cache + if attention_type == "mla": + q_rope_new = torch.randn( + num_new_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + k_rope_new = torch.randn( + num_new_tokens, rope_dim, dtype=input_dtype, device=device + ) + k_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, no_rope_dim, dtype=input_dtype, device=device + ) + ) + v_new = None + else: + q_rope_new = torch.randn( + num_new_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + k_rope_new = torch.randn( + num_new_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_kv_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + v_new = torch.randn( + num_new_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + pos_ids_new = torch.arange( + num_existing_tokens, + num_existing_tokens + num_new_tokens, + device=device, + dtype=torch.int32, + ) + + # Build metadata for new tokens (continue appending to first request) + num_pages_new_needed = (total_tokens + page_size - 1) // page_size + kv_page_indptr_new = torch.tensor( + [0, num_pages_new_needed] + [num_pages_new_needed] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indices_new = torch.arange( + num_pages_new_needed, dtype=torch.int32, device=device + ) + # For continuation, positions start at num_existing_tokens + batch_indices_new = torch.zeros(num_new_tokens, device=device, dtype=torch.int32) + positions_new = 
torch.arange( + num_existing_tokens, + num_existing_tokens + num_new_tokens, + device=device, + dtype=torch.int32, + ) + + # Snapshot existing cache for later comparison + if attention_type == "mla": + ckv_cache_before = ckv_cache.clone() + kpe_cache_before = kpe_cache.clone() + else: + k_cache_before = k_cache.clone() + v_cache_before = v_cache.clone() + + # Append new tokens + if attention_type == "mla": + q_rope_out_new, q_nope_out_new = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_new, + k_rope_new, + q_nope_new, + k_nope_new, + None, + rope_ref.cos_sin_cache, + pos_ids_new, + (ckv_cache, kpe_cache), + kv_page_indices_new, + kv_page_indptr_new, + batch_indices_new, + positions_new, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + else: + q_rope_out_new, q_nope_out_new = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_new, + k_rope_new, + q_nope_new, + k_nope_new, + v_new, + rope_ref.cos_sin_cache, + pos_ids_new, + (k_cache, v_cache), + kv_page_indices_new, + kv_page_indptr_new, + batch_indices_new, + positions_new, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + + # Verify Q outputs for new tokens (handle None for no_rope_dim == 0) + q_in_new = ( + q_rope_new + if q_nope_new is None + else torch.cat([q_rope_new, q_nope_new], dim=-1) + ) + k_in_new = ( + k_rope_new + if k_nope_new is None + else torch.cat([k_rope_new, k_nope_new], dim=-1) + ) + q_out_f16_ref_new, k_out_f16_ref_new = rope_ref.forward_native( + pos_ids_new, q_in_new, k_in_new + ) + q_out_f8_ref_new = q_out_f16_ref_new.to(quant_dtype) + k_out_f8_ref_new = k_out_f16_ref_new.to(quant_dtype) + + torch.testing.assert_close( + q_out_f8_ref_new[..., :rope_dim].float(), + q_rope_out_new.float(), + rtol=2e-1, + atol=1e-2, + ) + torch.testing.assert_close( + q_out_f8_ref_new[..., rope_dim:].float(), + q_nope_out_new.float(), + rtol=2e-1, + atol=1e-2, + ) + + # FP8 tolerances + if quant_dtype == torch.float8_e4m3fn: + rtol_val, atol_val = 0.25, 0.5 + else: + rtol_val, atol_val = 0.25, 1.0 + + # Verify existing cache entries remain unchanged + if attention_type == "mla": + # Check that entries before num_existing_tokens are unchanged + for i in range(num_existing_tokens): + b = batch_indices_existing[i].item() + pos = positions_existing[i].item() + page_iter = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) // page_size + entry_idx = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) % page_size + page_idx = kv_page_indices_existing[page_iter].item() + torch.testing.assert_close( + ckv_cache[page_idx, entry_idx, :].float(), + ckv_cache_before[page_idx, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing CKV cache entry {i} was modified", + ) + torch.testing.assert_close( + kpe_cache[page_idx, entry_idx, :].float(), + kpe_cache_before[page_idx, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing KPE cache entry {i} was modified", + ) + else: + for i in range(num_existing_tokens): + b = batch_indices_existing[i].item() + pos = positions_existing[i].item() + page_iter = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) // page_size + entry_idx = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) % page_size + page_idx = kv_page_indices_existing[page_iter].item() + if kv_layout == "NHD": + torch.testing.assert_close( + 
k_cache[page_idx, entry_idx, :, :].float(), + k_cache_before[page_idx, entry_idx, :, :].float(), + rtol=0, + atol=0, + msg=f"Existing K cache entry {i} was modified", + ) + torch.testing.assert_close( + v_cache[page_idx, entry_idx, :, :].float(), + v_cache_before[page_idx, entry_idx, :, :].float(), + rtol=0, + atol=0, + msg=f"Existing V cache entry {i} was modified", + ) + else: # HND + torch.testing.assert_close( + k_cache[page_idx, :, entry_idx, :].float(), + k_cache_before[page_idx, :, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing K cache entry {i} was modified", + ) + torch.testing.assert_close( + v_cache[page_idx, :, entry_idx, :].float(), + v_cache_before[page_idx, :, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing V cache entry {i} was modified", + ) + + # Verify new cache entries are correct + if attention_type == "mla": + k_rope_ref_new = k_out_f8_ref_new[..., :rope_dim] + k_nope_ref_new = k_out_f8_ref_new[..., rope_dim:] + + for i in range(num_new_tokens): + b = batch_indices_new[i].item() + pos = positions_new[i].item() + page_iter = (kv_page_indptr_new[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr_new[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices_new[page_iter].item() + torch.testing.assert_close( + ckv_cache[page_idx, entry_idx, :].float(), + k_nope_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + kpe_cache[page_idx, entry_idx, :].float(), + k_rope_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + else: + quant_scale_kv = 1.0 + v_ref_tokens_new = (v_new * quant_scale_kv).to(quant_dtype) + + for i in range(num_new_tokens): + b = batch_indices_new[i].item() + pos = positions_new[i].item() + page_iter = (kv_page_indptr_new[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr_new[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices_new[page_iter].item() + if kv_layout == "NHD": + torch.testing.assert_close( + k_cache[page_idx, entry_idx, :, :].float(), + k_out_f8_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + v_cache[page_idx, entry_idx, :, :].float(), + v_ref_tokens_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + else: # HND + torch.testing.assert_close( + k_cache[page_idx, :, entry_idx, :].float(), + k_out_f8_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + v_cache[page_idx, :, entry_idx, :].float(), + v_ref_tokens_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + + @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @@ -492,6 +1383,10 @@ def test_mla_rope_quantize( enable_pdl, ): device = "cuda:0" + # Fixed seed for reproducibility across tests + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) num_qo_heads = 128 q_in = torch.randn(num_tokens, num_qo_heads, 576, dtype=input_dtype, device=device) k_in = torch.randn(num_tokens, 576, dtype=input_dtype, device=device) From a9f71bd88caca191217753171cdff1bba3212068 Mon Sep 17 00:00:00 2001 From: Lain Date: Mon, 17 Nov 2025 23:53:29 -0800 Subject: [PATCH 066/130] [API change] Allow using torch.Tensor for scales for trtllm-gen attention (#2084) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description - change `bmm1_scale` 
and `bmm2_scale` to `Union[float, torch.Tensor]`. notice that when using tensor, it must be applied by log2e - **remove the `bmm1_scale_log2_tensor` and `bmm2_scale_tensor` in the `xqa_batch_decode_with_kv_cache_mla`** - update trtllm-gen FMHA kernels TODO: do the same refactor for xqa kernels. The support for the device side scales was removed in #2033 ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Attention scale parameters now accept either floats or 1-element tensors across prefill, decode and runtime; tensor scales are validated and applied on-device and pointer-backed scale paths are supported. * **Chores** * Updated FMHA artifact path and checksum constants; added a public utility import and removed an obsolete inline comment. * **Tests** * Updated tests to exercise device/tensor-or-scalar scale flows, removed legacy per-tensor call-site args, and added device-scale parametrization for several test variants. --------- Signed-off-by: Siyuan Fu --- csrc/trtllm_fmha_kernel_launcher.cu | 142 +++++++++++++----- flashinfer/artifacts.py | 4 +- flashinfer/decode.py | 83 +++++----- flashinfer/prefill.py | 54 ++++--- include/flashinfer/trtllm/fmha/kernelParams.h | 1 - tests/attention/test_trtllm_gen_attention.py | 142 +++++++++++++----- tests/attention/test_trtllm_gen_mla.py | 19 --- 7 files changed, 284 insertions(+), 161 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 89d958ce7f..5c1de17bb0 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -28,6 +29,7 @@ #include "tvm_ffi_utils.h" using tvm::ffi::Optional; +using tvm::ffi::Variant; namespace flashinfer { @@ -78,9 +80,10 @@ void trtllm_paged_attention_launcher( int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, - double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, - int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sm_count, - bool enable_pdl, int64_t workspace_size, cudaStream_t stream) { + double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr, + const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, + int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl, + int64_t workspace_size, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -117,8 +120,12 @@ void 
trtllm_paged_attention_launcher( runner_params.vStrideBatch = kv_stride_batch; runner_params.mNumPagesInMemPool = num_pages_in_mem_pool; runner_params.stream = stream; + // the scaleSoftmaxLog2Ptr and outputScalePtr have higher priority than the scaleSoftmaxLog2 and + // outputScale. if they are not nullptr, then scaleSoftmaxLog2 and outputScale will be ignored runner_params.outputScale = bmm2_scale; + runner_params.outputScalePtr = bmm2_scale_ptr; runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E; + runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr; runner_params.oSfPtr = out_scale_factor; runner_params.mSfStartTokenIdx = o_sf_start_index; runner_params.mScaleSfO = o_sf_scale; @@ -197,11 +204,12 @@ inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_T void trtllm_paged_attention_decode(TensorView out, Optional out_scale_factor, TensorView query, TensorView key_cache, TensorView value_cache, TensorView workspace_buffer, TensorView block_tables, - TensorView seq_lens, int64_t max_kv_len, double bmm1_scale, - double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, - int64_t o_sf_start_index, int64_t window_left, int64_t sm_count, - bool enable_pdl, int64_t workspace_size, - Optional attention_sinks) { + TensorView seq_lens, int64_t max_kv_len, + Variant bmm1_scale, + Variant bmm2_scale, double o_sf_scale, + int64_t o_sf_vec_size, int64_t o_sf_start_index, + int64_t window_left, int64_t sm_count, bool enable_pdl, + int64_t workspace_size, Optional attention_sinks) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); @@ -250,7 +258,25 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal << "attention_sinks must be a float tensor"; attention_sinks_ptr = static_cast(attention_sinks.value().data_ptr()); } - + auto maybe_bmm1_scale_value = bmm1_scale.as(); + auto maybe_bmm2_scale_value = bmm2_scale.as(); + auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as(); + auto maybe_bmm2_scale_tensor = bmm2_scale.as(); + TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(), + "bmm1_scale must be either a double or a tensor"); + TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(), + "bmm2_scale must be either a double or a tensor"); + double bmm1_scale_value = + maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0; + double bmm2_scale_value = + maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0; + float* bmm1_scale_log2_ptr = + maybe_bmm1_scale_log2_tensor.has_value() + ? static_cast(maybe_bmm1_scale_log2_tensor.value().data_ptr()) + : nullptr; + float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value() + ? 
static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) + : nullptr; trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), @@ -259,21 +285,20 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, - kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale, - bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, - enable_pdl, workspace_size, stream); + kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, + bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, + o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, + stream); } -void trtllm_paged_attention_context(TensorView out, Optional out_scale_factor, - TensorView query, TensorView key_cache, TensorView value_cache, - TensorView workspace_buffer, TensorView block_tables, - TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len, - double bmm1_scale, double bmm2_scale, double o_sf_scale, - int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t batch_size, int64_t window_left, - TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, - int64_t sm_count, bool enable_pdl, int64_t workspace_size, - Optional attention_sinks) { +void trtllm_paged_attention_context( + TensorView out, Optional out_scale_factor, TensorView query, TensorView key_cache, + TensorView value_cache, TensorView workspace_buffer, TensorView block_tables, + TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len, + Variant bmm1_scale, Variant bmm2_scale, + double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size, + int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, + bool enable_pdl, int64_t workspace_size, Optional attention_sinks) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); @@ -312,6 +337,26 @@ void trtllm_paged_attention_context(TensorView out, Optional out_sca attention_sinks_ptr = static_cast(attention_sinks.value().data_ptr()); } + auto maybe_bmm1_scale_value = bmm1_scale.as(); + auto maybe_bmm2_scale_value = bmm2_scale.as(); + auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as(); + auto maybe_bmm2_scale_tensor = bmm2_scale.as(); + TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(), + "bmm1_scale must be either a double or a tensor"); + TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(), + "bmm2_scale must be either a double or a tensor"); + double bmm1_scale_value = + maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0; + double bmm2_scale_value = + maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0; + float* bmm1_scale_log2_ptr = + maybe_bmm1_scale_log2_tensor.has_value() + ? static_cast(maybe_bmm1_scale_log2_tensor.value().data_ptr()) + : nullptr; + float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value() + ? 
static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) + : nullptr; + trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), @@ -321,8 +366,9 @@ void trtllm_paged_attention_context(TensorView out, Optional out_sca q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, - max_num_blocks_per_seq, bmm1_scale, bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, - window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream); + max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, + bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, + enable_pdl, workspace_size, stream); } void trtllm_ragged_attention_launcher( @@ -331,8 +377,9 @@ void trtllm_ragged_attention_launcher( Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, int64_t max_q_len, int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_v, int64_t sum_seq_q, int64_t sum_seq_kv, double bmm1_scale, double bmm2_scale, - double o_sf_scale, int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, - bool is_causal, int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch, + const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, + int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal, + int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch, int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch, int64_t workspace_size, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { @@ -360,8 +407,12 @@ void trtllm_ragged_attention_launcher( runner_params.mQkvLayout = QkvLayout::SeparateQkv; runner_params.mMultiProcessorCount = sm_count; runner_params.stream = stream; + // the scaleSoftmaxLog2Ptr and outputScalePtr have higher priority than the scaleSoftmaxLog2 and + // outputScale. 
if they are not nullptr, then scaleSoftmaxLog2 and outputScale will be ignored runner_params.outputScale = bmm2_scale; + runner_params.outputScalePtr = bmm2_scale_ptr; runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E; + runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr; runner_params.mScaleSfO = o_sf_scale; runner_params.mChunkedAttentionSize = INT_MAX; // disable chunked attention by INT_MAX runner_params.mAttentionWindowSize = @@ -414,12 +465,12 @@ void trtllm_ragged_attention_launcher( void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, TensorView value, TensorView workspace_buffer, TensorView seq_lens, int64_t max_q_len, - int64_t max_kv_len, double bmm1_scale, double bmm2_scale, - double o_sf_scale, int64_t batch_size, int64_t window_left, - TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, - int64_t sm_count, bool enable_pdl, bool is_causal, - int64_t workspace_size, Optional attention_sinks, - Optional lse) { + int64_t max_kv_len, Variant bmm1_scale, + Variant bmm2_scale, double o_sf_scale, + int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q, + TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, + bool is_causal, int64_t workspace_size, + Optional attention_sinks, Optional lse) { float* attention_sinks_ptr = nullptr; if (attention_sinks.has_value()) { TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32) @@ -453,15 +504,34 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T int v_stride_heads = value.stride(1); int v_stride_batch = value.numel(); + auto maybe_bmm1_scale_value = bmm1_scale.as(); + auto maybe_bmm2_scale_value = bmm2_scale.as(); + auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as(); + auto maybe_bmm2_scale_tensor = bmm2_scale.as(); + TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(), + "bmm1_scale must be either a double or a tensor"); + TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(), + "bmm2_scale must be either a double or a tensor"); + double bmm1_scale_value = + maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0; + double bmm2_scale_value = + maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0; + float* bmm1_scale_log2_ptr = + maybe_bmm1_scale_log2_tensor.has_value() + ? static_cast(maybe_bmm1_scale_log2_tensor.value().data_ptr()) + : nullptr; + float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value() + ? 
static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) + : nullptr; trtllm_ragged_attention_launcher( out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(), workspace_buffer.data_ptr(), static_cast(seq_lens.data_ptr()), static_cast(cum_seq_lens_q.data_ptr()), static_cast(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len, - num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale, - bmm2_scale, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal, - k_stride_keys_values, k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads, - v_stride_batch, workspace_size, stream); + num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value, + bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left, + sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch, + v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream); } namespace trtllm_cubin_loader { diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 60853ecd20..cfb2862e47 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -87,7 +87,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. """ - TRTLLM_GEN_FMHA: str = "b793e1b2cf7c419f070372ba55bbe53ca6fb9016/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "1e49deb33ec20018ae0acf1d956a579578069da1/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988" ) @@ -107,7 +107,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "20c017db0761a30130f05080ed2078f6c8044c0c2b3be7c4353ec740034b4432" + "66757498f573430583d63b04c02bf9e38306eefe2ce31df9b5d923d99bd15d84" ) TRTLLM_GEN_BMM: str = ( "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf" diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 5826e743da..4f4e8b0215 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -42,6 +42,7 @@ get_single_prefill_module, ) from .utils import ( + log2e, FP4Tensor, MaskMode, PosEncodingMode, @@ -1880,8 +1881,8 @@ def _paged_run( block_tables: torch.Tensor, seq_lens: torch.Tensor, max_seq_len: int, - bmm1_scale: float, # todo(Yingyi): add dynamic scale tensor later - bmm2_scale: float, + bmm1_scale: Union[float, torch.Tensor], + bmm2_scale: Union[float, torch.Tensor], workspace_size: int, window_left: int = -1, enable_pdl: bool = None, @@ -1893,12 +1894,11 @@ def _paged_run( if self._sm_count is None: self._sm_count = get_device_sm_count(query.device) - bmm1_scale = ( - bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale - ) - bmm2_scale = ( - bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale - ) + if isinstance(bmm1_scale, torch.Tensor): + assert bmm1_scale.dtype == torch.float32 + bmm1_scale = bmm1_scale * log2e + if isinstance(bmm2_scale, torch.Tensor): + assert bmm2_scale.dtype == torch.float32 self._op.trtllm_paged_attention_decode( out, @@ -2066,8 +2066,8 @@ def trtllm_batch_decode_with_kv_cache( block_tables: torch.Tensor, seq_lens: torch.Tensor, max_seq_len: int, - bmm1_scale: float, - bmm2_scale: float, # todo(Yingyi): add dynamic scale tensor later + bmm1_scale: Union[float, torch.Tensor] = 1.0, + bmm2_scale: Union[float, torch.Tensor] = 1.0, window_left: int = -1, out: Optional[Union[torch.Tensor, FP4Tensor]] = None, out_dtype: 
Optional[Union[torch.dtype, str]] = None, @@ -2105,11 +2105,13 @@ def trtllm_batch_decode_with_kv_cache( max_seq_len : int max sequence length for kv_cache - bmm1_scale : float + bmm1_scale : Union[float, torch.Tensor] fused scale for bmm1 input. + when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. - bmm2_scale : float + bmm2_scale : Union[float, torch.Tensor] fused scale for bmm2 input. + when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. window_left : int = -1 The left (inclusive) window size for the attention window, when set to ``-1``, the window @@ -2173,6 +2175,11 @@ def trtllm_batch_decode_with_kv_cache( ) if backend == "xqa": + # TODO(Siyuan): support device scale factors, which was removed in #2033 + if not isinstance(bmm1_scale, float): + bmm1_scale = bmm1_scale.item() + if not isinstance(bmm2_scale, float): + bmm2_scale = bmm2_scale.item() # xqa backend doesn't support nvfp4 output if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): raise ValueError("xqa backend does not support nvfp4 output") @@ -2287,12 +2294,11 @@ def trtllm_batch_decode_with_kv_cache( else: raise ValueError(f"Invalid out_dtype: {out_dtype}") - bmm1_scale = ( - bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale - ) - bmm2_scale = ( - bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale - ) + if isinstance(bmm1_scale, torch.Tensor): + assert bmm1_scale.dtype == torch.float32 + bmm1_scale = bmm1_scale * log2e + if isinstance(bmm2_scale, torch.Tensor): + assert bmm2_scale.dtype == torch.float32 run_func( out, @@ -2533,10 +2539,8 @@ def trtllm_batch_decode_with_kv_cache_mla( seq_lens: torch.Tensor, max_seq_len: int, out: Optional[torch.Tensor] = None, - bmm1_scale: Optional[float] = 1.0, - bmm2_scale: Optional[float] = 1.0, - bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, - bmm2_scale_tensor: Optional[torch.Tensor] = None, + bmm1_scale: Union[float, torch.Tensor] = 1.0, + bmm2_scale: Union[float, torch.Tensor] = 1.0, sinks: Optional[List[torch.Tensor]] = None, enable_pdl: bool = None, backend: str = "auto", @@ -2554,9 +2558,9 @@ def trtllm_batch_decode_with_kv_cache_mla( max_seq_len: max sequence length for kv_cache out: output tensor, if not provided, will be allocated internally bmm1_scale: fused scale for mla bmm1 input. + when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. bmm2_scale: fused scale for mla bmm2 input. - bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in. - bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input. + when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. sinks: additional value per head in the denominator of the softmax. backend : str = "auto" The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. @@ -2569,8 +2573,8 @@ def trtllm_batch_decode_with_kv_cache_mla( bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) bmm2_scale = v_scale * o_scale or, - bmm1_scale_log2_tensor = [q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) * M_LOG2E] - bmm2_scale_tensor = [v_scale * o_scale] + bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)) + bmm2_scale = torch.Tensor([v_scale * o_scale]) The two scale factors should be static constant for cuda graph capture. 
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. @@ -2587,6 +2591,11 @@ def trtllm_batch_decode_with_kv_cache_mla( "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" ) if backend == "xqa": + # TODO(Siyuan): support device scale factors, which was removed in #2033 + if not isinstance(bmm1_scale, float): + bmm1_scale = bmm1_scale.item() + if not isinstance(bmm2_scale, float): + bmm2_scale = bmm2_scale.item() if ( get_compute_capability(query.device)[0] != 12 or query.dtype != torch.float8_e4m3fn @@ -2653,15 +2662,11 @@ def trtllm_batch_decode_with_kv_cache_mla( "out", ) - if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: - # dynamic scale factors - if ( - query.dtype != torch.float8_e4m3fn - or kv_cache.dtype != torch.float8_e4m3fn - ): - raise ValueError( - "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation" - ) + if isinstance(bmm1_scale, torch.Tensor): + assert bmm1_scale.dtype == torch.float32 + bmm1_scale = bmm1_scale * log2e + if isinstance(bmm2_scale, torch.Tensor): + assert bmm2_scale.dtype == torch.float32 run_func( out, @@ -2701,10 +2706,9 @@ def xqa_batch_decode_with_kv_cache_mla( seq_lens: torch.Tensor, max_seq_len: int, out: Optional[torch.Tensor] = None, + # TODO(Siyuan): support device scale factors, which was removed in #2033 bmm1_scale: Optional[float] = 1.0, bmm2_scale: Optional[float] = 1.0, - bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, - bmm2_scale_tensor: Optional[torch.Tensor] = None, sinks: Optional[List[torch.Tensor]] = None, enable_pdl: bool = None, ) -> torch.Tensor: @@ -2722,17 +2726,12 @@ def xqa_batch_decode_with_kv_cache_mla( out: output tensor, if not provided, will be allocated internally bmm1_scale: fused scale for mla bmm1 input. bmm2_scale: fused scale for mla bmm2 input. - bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in. - bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input. sinks: additional value per head in the denominator of the softmax. Note: In MLA, the actual BMM1 and BMM2 scales applied would be fused as: bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) bmm2_scale = v_scale * o_scale - or, - bmm1_scale_log2_tensor = [q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) * M_LOG2E] - bmm2_scale_tensor = [v_scale * o_scale] The two scale factors should be static constant for cuda graph capture. Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. 
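For reference, the tensor-scale handling added to the decode wrapper above can be reduced to a small standalone sketch. This is illustrative only: `normalize_scales` and the local `log2e` constant stand in for the wrapper logic, which imports `log2e` from `flashinfer.utils`.

```python
import math

import torch

log2e = math.log2(math.e)  # same factor as M_LOG2E in the CUDA launcher


def normalize_scales(bmm1_scale, bmm2_scale):
    # Floats are passed through unchanged (the launcher applies M_LOG2E on the
    # scalar path). float32 tensors stay on device, and the bmm1 tensor is
    # pre-multiplied by log2(e) here because the pointer path is consumed as
    # scaleSoftmaxLog2Ptr without further scaling.
    if isinstance(bmm1_scale, torch.Tensor):
        assert bmm1_scale.dtype == torch.float32
        bmm1_scale = bmm1_scale * log2e
    if isinstance(bmm2_scale, torch.Tensor):
        assert bmm2_scale.dtype == torch.float32
    return bmm1_scale, bmm2_scale


# Both call styles are accepted:
print(normalize_scales(0.5, 1.0))
print(
    normalize_scales(
        torch.tensor(0.5, dtype=torch.float32),
        torch.tensor(1.0, dtype=torch.float32),
    )
)
```
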
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 6b4353011f..47d725c5d3 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -36,6 +36,7 @@ from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens from .quantization import packbits, segment_packbits from .utils import ( + log2e, FP4Tensor, MaskMode, PosEncodingMode, @@ -190,8 +191,8 @@ def _paged_run( seq_lens: torch.Tensor, max_q_len: int, max_kv_len: int, - bmm1_scale: float, - bmm2_scale: float, + bmm1_scale: Union[float, torch.Tensor], + bmm2_scale: Union[float, torch.Tensor], batch_size: int, cum_seq_lens_q: torch.Tensor, cum_seq_lens_kv: torch.Tensor, @@ -204,12 +205,11 @@ def _paged_run( sm_count = get_device_sm_count(query.device) if out is None: out = torch.empty_like(query) - bmm1_scale = ( - bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale - ) - bmm2_scale = ( - bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale - ) + if isinstance(bmm1_scale, torch.Tensor): + assert bmm1_scale.dtype == torch.float32 + bmm1_scale = bmm1_scale * log2e + if isinstance(bmm2_scale, torch.Tensor): + assert bmm2_scale.dtype == torch.float32 op.trtllm_paged_attention_context( out, None, # fp4 output not supported in wrapper api yet. @@ -3201,8 +3201,8 @@ def trtllm_ragged_attention_deepseek( seq_lens: torch.Tensor, max_q_len: int, max_kv_len: int, - bmm1_scale: float, - bmm2_scale: float, + bmm1_scale: Union[float, torch.Tensor], + bmm2_scale: Union[float, torch.Tensor], o_sf_scale: float, batch_size: int, window_left: int, @@ -3232,10 +3232,12 @@ def trtllm_ragged_attention_deepseek( max query length max_kv_len : int max key/value length - bmm1_scale : float + bmm1_scale : Union[float, torch.Tensor] scale for bmm1, scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5) - bmm2_scale : float + when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. + bmm2_scale : Union[float, torch.Tensor] scale for bmm2, scale_v + when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. o_sf_scale : float scale for output batch_size : int @@ -3289,6 +3291,12 @@ def trtllm_ragged_attention_deepseek( dtype=torch.float32, ) + if isinstance(bmm1_scale, torch.Tensor): + assert bmm1_scale.dtype == torch.float32 + bmm1_scale = bmm1_scale * log2e + if isinstance(bmm2_scale, torch.Tensor): + assert bmm2_scale.dtype == torch.float32 + workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() run_func( out, @@ -3327,8 +3335,8 @@ def trtllm_batch_context_with_kv_cache( seq_lens: torch.Tensor, max_q_len: int, max_kv_len: int, - bmm1_scale: float, - bmm2_scale: float, + bmm1_scale: Union[float, torch.Tensor], + bmm2_scale: Union[float, torch.Tensor], batch_size: int, cum_seq_lens_q: torch.Tensor, cum_seq_lens_kv: torch.Tensor, @@ -3362,10 +3370,12 @@ def trtllm_batch_context_with_kv_cache( max sequence length for query max_kv_len : int max sequence length for kv_cache - bmm1_scale : float + bmm1_scale : Union[float, torch.Tensor] fused scale for bmm1 input. - bmm2_scale : float + when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. + bmm2_scale : Union[float, torch.Tensor] fused scale for bmm2 input. + when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. 
batch_size : int batch size cum_seq_lens_q : torch.Tensor @@ -3494,13 +3504,11 @@ def trtllm_batch_context_with_kv_cache( else: raise ValueError(f"Invalid out_dtype: {out_dtype}") - bmm1_scale = ( - bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale - ) - bmm2_scale = ( - bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale - ) - + if isinstance(bmm1_scale, torch.Tensor): + assert bmm1_scale.dtype == torch.float32 + bmm1_scale = bmm1_scale * log2e + if isinstance(bmm2_scale, torch.Tensor): + assert bmm2_scale.dtype == torch.float32 workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() run_func( out, diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 533b98c9e0..c184ad9e10 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -715,7 +715,6 @@ struct KernelParams { params.mNumHeadsKv = options.mNumHeadsKv; params.mNumHeadsQPerKv = options.mNumHeadsQPerKv; params.mNumHiddenEltsO = options.mNumHeadsQ * options.mHeadDimQk; - // todo(Yingyi): might take a scalar tensor later params.mOutputScale = options.outputScale; params.mScaleSoftmaxLog2 = options.scaleSoftmaxLog2; params.mStartTokenIdxSfO = options.mSfStartTokenIdx; diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 0d80e9cf90..642c437e59 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -339,39 +339,7 @@ def unpack_compare_nvfp4( return output_unpacked, output_ref -@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) -@pytest.mark.parametrize( - "batch_size,page_size,num_kv_heads,head_grp_size", - [ - (4, 16, 2, 1), - (4, 32, 4, 5), - (4, 64, 4, 8), - (128, 16, 2, 5), - (128, 32, 4, 1), - (128, 64, 2, 8), - (256, 16, 4, 8), - (256, 32, 2, 8), - (256, 64, 4, 1), - (256, 64, 4, 5), - ], -) -@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left -@pytest.mark.parametrize( - "q_dtype,kv_dtype,o_dtype", - [ - ("bf16", "bf16", "bf16"), - ("fp16", "fp16", "fp16"), - ("fp8", "fp8", "bf16"), - ("fp8", "fp8", "fp16"), - ("fp8", "fp8", "fp8"), - ("fp8", "fp8", "nvfp4"), - ], -) -@pytest.mark.parametrize("enable_pdl", [True, False, None]) -@pytest.mark.parametrize("enable_sink", [True, False]) -@pytest.mark.parametrize("max_q_len", [511]) -@pytest.mark.parametrize("max_kv_len", [2047]) -def test_trtllm_batch_prefill( +def _test_trtllm_batch_prefill( kv_layout, batch_size, page_size, @@ -385,6 +353,7 @@ def test_trtllm_batch_prefill( enable_sink, max_q_len, max_kv_len, + device_scale, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] != 10: @@ -485,6 +454,16 @@ def test_trtllm_batch_prefill( ) # Run trtllm-gen function call + bmm1_scale = q_scale * k_scale * sm_scale + bmm2_scale = v_scale / o_scale + if isinstance(bmm1_scale, torch.Tensor) and not device_scale: + bmm1_scale = bmm1_scale.item() + elif not isinstance(bmm1_scale, torch.Tensor) and device_scale: + bmm1_scale = torch.tensor(bmm1_scale, device=GPU_DEVICE, dtype=torch.float32) + if isinstance(bmm2_scale, torch.Tensor) and not device_scale: + bmm2_scale = bmm2_scale.item() + elif not isinstance(bmm2_scale, torch.Tensor) and device_scale: + bmm2_scale = torch.tensor(bmm2_scale, device=GPU_DEVICE, dtype=torch.float32) output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q.contiguous(), kv_cache, @@ 
-493,8 +472,8 @@ def test_trtllm_batch_prefill( seq_lens.to(GPU_DEVICE), torch.max(q_lens).item(), torch.max(seq_lens).item(), - q_scale * k_scale * sm_scale, # bmm1_scale - v_scale / o_scale, # bmm2_scale + bmm1_scale, # bmm1_scale + bmm2_scale, # bmm2_scale batch_size, q_indptr, kv_indptr, @@ -568,6 +547,71 @@ def test_trtllm_batch_prefill( assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize( + "batch_size,page_size,num_kv_heads,head_grp_size", + [ + (4, 16, 2, 1), + (4, 32, 4, 5), + (4, 64, 4, 8), + (128, 16, 2, 5), + (128, 32, 4, 1), + (128, 64, 2, 8), + (256, 16, 4, 8), + (256, 32, 2, 8), + (256, 64, 4, 1), + (256, 64, 4, 5), + ], +) +@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("bf16", "bf16", "bf16"), + ("fp16", "fp16", "fp16"), + ("fp8", "fp8", "bf16"), + ("fp8", "fp8", "fp16"), + ("fp8", "fp8", "fp8"), + ("fp8", "fp8", "nvfp4"), + ], +) +@pytest.mark.parametrize("enable_pdl", [None]) +@pytest.mark.parametrize("enable_sink", [True, False]) +@pytest.mark.parametrize("max_q_len", [511]) +@pytest.mark.parametrize("max_kv_len", [2047]) +def test_trtllm_batch_prefill( + kv_layout, + batch_size, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_q_len, + max_kv_len, +): + _test_trtllm_batch_prefill( + kv_layout, + batch_size, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + enable_sink, + max_q_len, + max_kv_len, + kv_dtype == "fp8", + ) + + @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize( "batch_size,page_size,num_kv_heads,head_grp_size", @@ -601,7 +645,7 @@ def test_trtllm_batch_prefill_bs1( max_q_len, max_kv_len, ): - test_trtllm_batch_prefill( + _test_trtllm_batch_prefill( kv_layout, batch_size, page_size, @@ -615,6 +659,7 @@ def test_trtllm_batch_prefill_bs1( enable_sink, max_q_len, max_kv_len, + False, ) @@ -634,6 +679,7 @@ def _test_trtllm_batch_decode( enable_sink, max_in_kv_len, head_dim, + device_scale=False, ): """ Common function for testing trtllm-gen decode. 
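The float-versus-tensor branching that the prefill hunk above and the decode hunk below both add could equally be expressed as one shared helper. A minimal sketch follows; the helper name `_as_scale` is hypothetical and not part of this patch, and `device` stands in for the tests' `GPU_DEVICE` constant.

```python
import torch


def _as_scale(value, device_scale, device="cuda"):
    # Hypothetical helper: return either a host-side scalar or a float32 device
    # tensor, mirroring the branching the tests perform inline on
    # bmm1_scale / bmm2_scale.
    if isinstance(value, torch.Tensor):
        return value if device_scale else value.item()
    if device_scale:
        return torch.tensor(value, device=device, dtype=torch.float32)
    return float(value)


# Usage mirroring the test bodies:
#   bmm1_scale = _as_scale(q_scale * k_scale * sm_scale, device_scale, GPU_DEVICE)
#   bmm2_scale = _as_scale(v_scale / o_scale, device_scale, GPU_DEVICE)
```
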
@@ -781,6 +827,16 @@ def _test_trtllm_batch_decode( ) # Run decode function call with specified backend + bmm1_scale = q_scale * k_scale * sm_scale + bmm2_scale = v_scale / o_scale + if isinstance(bmm1_scale, torch.Tensor) and not device_scale: + bmm1_scale = bmm1_scale.item() + elif not isinstance(bmm1_scale, torch.Tensor) and device_scale: + bmm1_scale = torch.tensor(bmm1_scale, device=GPU_DEVICE, dtype=torch.float32) + if isinstance(bmm2_scale, torch.Tensor) and not device_scale: + bmm2_scale = bmm2_scale.item() + elif not isinstance(bmm2_scale, torch.Tensor) and device_scale: + bmm2_scale = torch.tensor(bmm2_scale, device=GPU_DEVICE, dtype=torch.float32) output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( q.contiguous(), kv_cache, @@ -788,8 +844,8 @@ def _test_trtllm_batch_decode( page_table, seq_lens.to(GPU_DEVICE), torch.max(seq_lens).item(), - q_scale * k_scale * sm_scale, # bmm1_scale - v_scale / o_scale, # bmm2_scale + bmm1_scale, + bmm2_scale, window_left, # window_left out=out, out_dtype=out_dtype, @@ -976,6 +1032,7 @@ def test_trtllm_batch_decode( enable_sink, max_in_kv_len, head_dim, + kv_dtype == "fp8", ) @@ -997,6 +1054,7 @@ def test_trtllm_batch_decode( @pytest.mark.parametrize("enable_sink", [False]) @pytest.mark.parametrize("max_in_kv_len", [8192]) @pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("device_scale", [True, False]) def test_trtllm_batch_decode_bs1( kv_layout, batch_size, @@ -1012,6 +1070,7 @@ def test_trtllm_batch_decode_bs1( enable_sink, max_in_kv_len, head_dim, + device_scale, ): # Small number of test cases for batch size 1 pytest.xfail("trtllm-gen decode gets incorrect output with bs1") @@ -1031,6 +1090,7 @@ def test_trtllm_batch_decode_bs1( enable_sink, max_in_kv_len, head_dim, + device_scale, ) @@ -1063,6 +1123,7 @@ def test_trtllm_batch_decode_bs1( @pytest.mark.parametrize("enable_sink", [False]) @pytest.mark.parametrize("max_in_kv_len", [110]) @pytest.mark.parametrize("head_dim", [256]) +@pytest.mark.parametrize("device_scale", [True, False]) def test_trtllm_batch_decode_head_dim_256( kv_layout, batch_size, @@ -1078,6 +1139,7 @@ def test_trtllm_batch_decode_head_dim_256( enable_sink, max_in_kv_len, head_dim, + device_scale, ): # Small number of test cases for head_dim = 256 pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256") @@ -1097,6 +1159,7 @@ def test_trtllm_batch_decode_head_dim_256( enable_sink, max_in_kv_len, head_dim, + device_scale, ) @@ -1122,6 +1185,7 @@ def test_trtllm_batch_decode_head_dim_256( @pytest.mark.parametrize("enable_sink", [False]) @pytest.mark.parametrize("max_in_kv_len", [4096, 8192, 16384, 32768, 65536, 131072]) @pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("device_scale", [True, False]) def test_trtllm_batch_decode_long_sequence_length( kv_layout, batch_size, @@ -1137,6 +1201,7 @@ def test_trtllm_batch_decode_long_sequence_length( enable_sink, max_in_kv_len, head_dim, + device_scale, ): # Small number of test cases for long sequence length _test_trtllm_batch_decode( @@ -1155,6 +1220,7 @@ def test_trtllm_batch_decode_long_sequence_length( enable_sink, max_in_kv_len, head_dim, + device_scale, ) diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index 999eda2a8a..508fce831d 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -1,5 +1,3 @@ -import math - import pytest import torch @@ -123,21 +121,6 @@ def test_trtllm_batch_decode_mla( workspace_buffer = 
global_trtllm_gen_fmha_workspace_buffer workspace_buffer_ref = global_workspace_buffer - bmm1_log2_scale_tensor = ( - torch.tensor( - [scale / ((128 + 64) ** 0.5 * math.log2(math.e))], - dtype=torch.float32, - device=device, - ) - if dynamic_scale - else None - ) - bmm2_scale_tensor = ( - torch.tensor([1.0], dtype=torch.float32, device=device) - if dynamic_scale - else None - ) - # Run decode-MLA output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, @@ -151,8 +134,6 @@ def test_trtllm_batch_decode_mla( max_seq_len=max_seq_len, bmm1_scale=scale / ((128 + 64) ** 0.5), bmm2_scale=1.0, - bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, - bmm2_scale_tensor=bmm2_scale_tensor, enable_pdl=enable_pdl, backend=backend, ) From 875403e9f294736163462601c0c949765140aefe Mon Sep 17 00:00:00 2001 From: eigen <52445717+yyihuang@users.noreply.github.com> Date: Tue, 18 Nov 2025 09:44:15 -0500 Subject: [PATCH 067/130] refactor: update dpsk fused_moe test [2] (#2097) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Add shuffling and blockmajorK layout in dpskv3 fused_moe fp8_blockscaled tests. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Expanded MoE test suite with per-expert weight shuffling, optional block-layout conversion, selectable weight-processing modes, and dynamic kernel flags. * Added a reference FP8 block-scale validation path and centralized accuracy checks for clearer correctness verification. * **Refactor** * Centralized test utilities: quantization mode and test-skip logic moved into shared helpers for consistent gating across MoE tests. 
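Condensed from the test changes below, the per-expert weight shuffling and optional block-layout conversion exercised by the new parametrizations look roughly like this. This is a sketch only: `prepare_expert_weights` is a hypothetical name, and `epilogue_tile_m=64` / `block_k=128` are the values the test hard-codes.

```python
import torch

from flashinfer import shuffle_matrix_a
from flashinfer.fused_moe import WeightLayout
from flashinfer.fused_moe.core import convert_to_block_layout


def prepare_expert_weights(weights_fp8, layout, epilogue_tile_m=64, block_k=128):
    # weights_fp8: [num_local_experts, N, K] in float8_e4m3fn. The unshuffled
    # MajorK case uses the weights as-is; this helper covers the shuffled variants.
    prepared = []
    for w in weights_fp8:
        shuffled = shuffle_matrix_a(w.view(torch.uint8), epilogue_tile_m)
        if layout == WeightLayout.BlockMajorK:
            shuffled = convert_to_block_layout(shuffled, block_k)
        prepared.append(shuffled)
    return torch.stack(prepared).view(torch.float8_e4m3fn)
```
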
--------- Co-authored-by: Zihao Ye --- tests/moe/test_dpsk_fused_moe_fp8.py | 396 ++++++++++++++++++------- tests/moe/test_trtllm_gen_fused_moe.py | 75 +---- tests/moe/test_utils.py | 97 ++++++ 3 files changed, 392 insertions(+), 176 deletions(-) create mode 100644 tests/moe/test_utils.py diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py index 3ac4055128..a472ecc5a0 100644 --- a/tests/moe/test_dpsk_fused_moe_fp8.py +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -1,45 +1,32 @@ import pytest import torch -from flashinfer.fused_moe import trtllm_fp8_block_scale_moe, WeightLayout +from flashinfer import shuffle_matrix_a +from flashinfer.fused_moe.core import convert_to_block_layout from flashinfer.autotuner import autotune +from flashinfer.fused_moe import ( + WeightLayout, + trtllm_fp8_block_scale_moe, +) +from .test_utils import skip_checks, QuantMode +from flashinfer import GatedActType -def run( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, +def dequant_fp8_block_scaled( + intermediate_size: int, + hidden_size: int, hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor, gemm1_weights: torch.Tensor, gemm1_weights_scale: torch.Tensor, gemm2_weights: torch.Tensor, gemm2_weights_scale: torch.Tensor, - local_expert_offset: int, - routed_scaling_factor: float, - hidden_size: int, - intermediate_size: int, + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, num_experts_global: int, num_local_experts: int, - top_k: int, - n_group: int, - topk_group: int, ): - """ - - FP8 block-scale dequantization: float โ‰ˆ fp8 * scale - - DeepSeek-V3 no-aux routing: - s = sigmoid(logits) - s_with_bias = s + bias - group by n_group=8; per group take top-2 sum โ†’ pick topk_group=4 groups - on the kept groups, take global top_k=8 experts - combine with weights derived from s (without bias), normalized and - scaled by routed_scaling_factor - - Local computation: - only experts in [local_expert_offset, local_expert_offset + E_local) are - computed on this rank (GEMM1 โ†’ SwiGLU โ†’ GEMM2), then per-token weighted - accumulation. 
- """ - - # Fixed DeepSeek-V3/R1 geometry - H = hidden_size # deepseek v3: 7168 + # FP8 block-scale dequantization: float โ‰ˆ fp8 * scale + H = hidden_size I = intermediate_size # deepseek v3: 2048 E_local = gemm1_weights.shape[0] @@ -50,11 +37,6 @@ def run( assert E_global == num_experts_global, "num_experts_global shape mismatch" assert E_local == num_local_experts, "num_local_experts shape mismatch" - # Routing constants - TOP_K = top_k # deepseek v3: 8 - N_GROUP = n_group # deepseek v3: 8 - TOPK_GROUP = topk_group # deepseek v3: 4 - # Block counts num_hidden_blocks = H // BLOCK # 56 num_intermediate_blocks = I // BLOCK # 16 @@ -77,9 +59,6 @@ def run( ) assert routing_bias.shape[-1] == E_global - device = hidden_states.device - - # 1) FP8 block-scale dequantization # hidden_states: [T, H], scale: [H/128, T] (transposed layout) A_fp32 = hidden_states.to(torch.float32) A_scale = hidden_states_scale.to(torch.float32) # [H/128, T] @@ -106,6 +85,52 @@ def run( S2_expanded = torch.repeat_interleave(S2_expanded, BLOCK, dim=2) # [E, H, I] W2 = W2_fp32 * S2_expanded # [E, H, I] float32 + return A, W13, W2 + + +def _deepseek_moe_core( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + local_expert_offset: int, + routed_scaling_factor: float, + intermediate_size: int, + num_experts_global: int, + num_local_experts: int, + top_k: int, + n_group: int, + topk_group: int, + hidden_size: int, + A: torch.Tensor, + W13: torch.Tensor, + W2: torch.Tensor, +): + """ + - DeepSeek-V3 no-aux routing: + s = sigmoid(logits) + s_with_bias = s + bias + group by n_group=8; per group take top-2 sum โ†’ pick topk_group=4 groups + on the kept groups, take global top_k=8 experts + combine with weights derived from s (without bias), normalized and + scaled by routed_scaling_factor + - Local computation: + only experts in [local_expert_offset, local_expert_offset + E_local) are + computed on this rank (GEMM1 โ†’ SwiGLU โ†’ GEMM2), then per-token weighted + accumulation. 
+ """ + + # Routing constants + TOP_K = top_k # deepseek v3: 8 + N_GROUP = n_group # deepseek v3: 8 + TOPK_GROUP = topk_group # deepseek v3: 4 + + I = intermediate_size # deepseek v3: 2048 + H = hidden_size # deepseek v3: 7168 + E_local = num_local_experts + E_global = num_experts_global + T = routing_logits.shape[0] + + device = A.device + # 2) No-aux routing logits = routing_logits.to(torch.float32) # [T, E_global] bias = routing_bias.to(torch.float32).reshape(-1) # [E_global] @@ -190,6 +215,70 @@ def run( return output.to(torch.bfloat16) +def run_fp8_block_scale_moe_reference( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + hidden_states: torch.Tensor, + hidden_states_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + local_expert_offset: int, + routed_scaling_factor: float, + intermediate_size: int, + num_experts_global: int, + num_local_experts: int, + top_k: int, + n_group: int, + topk_group: int, + hidden_size: int, +): + I = intermediate_size # deepseek v3: 2048 + E_local = gemm1_weights.shape[0] + H = hidden_size # deepseek v3: 7168 + assert E_local == num_local_experts, "num_local_experts shape mismatch" + + E_global = routing_logits.shape[1] + assert E_global == num_experts_global, "num_experts_global shape mismatch" + + # FP8 block-scale dequantization + A, W13, W2 = dequant_fp8_block_scaled( + hidden_size=H, + intermediate_size=I, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + routing_logits=routing_logits, + routing_bias=routing_bias, + num_experts_global=E_global, + num_local_experts=E_local, + ) + + # DeepSeek-V3 no-aux routing + output = _deepseek_moe_core( + A=A, + W13=W13, + W2=W2, + routing_logits=routing_logits, + routing_bias=routing_bias, + num_experts_global=E_global, + num_local_experts=E_local, + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + hidden_size=H, + intermediate_size=I, + local_expert_offset=local_expert_offset, + routed_scaling_factor=routed_scaling_factor, + ) + + return output + + # ----------------------------- # Helpers: FP8 block quantization (dequant scale semantics) # ----------------------------- @@ -334,6 +423,67 @@ def generate_random_inputs_moe( } +def stats_accuracy( + ref_out: torch.Tensor, + fi_out: torch.Tensor, + atol: float = 1e-1, + rtol: float = 2e-1, + percent: float = 0.85, +): + H = ref_out.shape[1] + assert H == 7168 + + # Compare + ref_f32 = ref_out.float() + fi_f32 = fi_out.float() + + abs_diff = (ref_f32 - fi_f32).abs() + rel_diff = abs_diff / (fi_f32.abs() + 1e-8) + + print("\nComparison stats:") + print(f"Max abs diff: {abs_diff.max().item():.6e}") + print(f"Mean abs diff: {abs_diff.mean().item():.6e}") + print(f"Max rel diff: {rel_diff.max().item():.6e}") + print(f"Mean rel diff: {rel_diff.mean().item():.6e}") + + # Cosine similarity and MSE + cos_sim = torch.nn.functional.cosine_similarity( + ref_f32.flatten(), fi_f32.flatten(), dim=0 + ).item() + mse = torch.mean((ref_f32 - fi_f32) ** 2).item() + print(f"Cosine similarity: {cos_sim:.6f}") + print(f"MSE: {mse:.6e}") + + # Strict allclose + allclose = torch.allclose(ref_f32, fi_f32, atol=atol, rtol=rtol) + print(f"\nAllclose(atol={atol}, rtol={rtol}): {allclose}") + + if not allclose: + # Show top-5 largest absolute errors + flat = abs_diff.flatten() + k = min(5, 
flat.numel()) + topv, topi = torch.topk(flat, k) + print("\nTop-5 absolute error locations:") + for rank in range(k): + idx = topi[rank].item() + t = idx // H + h = idx % H + print( + f" [t={t}, h={h}]: ref={ref_f32.flatten()[idx].item():.6e}, " + f"fi={fi_f32.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}" + ) + + left = (ref_f32 - fi_f32).abs() + right = atol + rtol * fi_f32.abs() + ok = left <= right + hit_ratio = ok.float().mean().item() + print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {percent * 100:.2f}%)") + + assert hit_ratio >= percent, ( + f"Hit ratio {hit_ratio * 100:.2f}% is less than required {percent * 100:.2f}%" + ) + + # Max num tokens to tune for trtllm-gen fused moe TUNE_MAX_NUM_TOKENS = 4096 @@ -399,29 +549,79 @@ def generate_random_inputs_moe( ], ) @pytest.mark.parametrize("enable_pdl", [True, False]) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + }, + id="NoShuffle_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + }, + id="Shuffled_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + }, + id="Shuffled_BlockMajorK", + ), + ], +) def test_correctness_dpsk_fp8_fused_moe( - seq_len, - local_expert_offset, - use_bias, - intermediate_size, - routing_config, - enable_pdl, + seq_len: int, + local_expert_offset: int, + use_bias: bool, + intermediate_size: int, + routing_config: dict, + enable_pdl: bool, + weight_processing: dict, atol: float = 1e-1, rtol: float = 2e-1, percent: float = 0.85, ): - compatible_intermediate_size = routing_config["compatible_intermediate_size"] - if intermediate_size not in compatible_intermediate_size: - pytest.skip( - f"Intermediate size {intermediate_size} is not compatible with routing config {routing_config}" - ) - if not torch.cuda.is_available(): pytest.skip("CUDA not available") if trtllm_fp8_block_scale_moe is None: pytest.skip("flashinfer fused_moe kernel not available") + # Create a mock MoE implementation for skip_checks + class FP8BlockScaleMoe: + def __init__(self): + self.name = "FP8BlockScale" + self.quant_mode = QuantMode.FP8_BLOCK_SCALE + + moe_impl = FP8BlockScaleMoe() + + # Make copies of config dicts to avoid modifying the original parametrize values + routing_config = dict(routing_config) + weight_processing = dict(weight_processing) + + # Ensure they have compatible_moe_impls + if "compatible_moe_impls" not in routing_config: + routing_config["compatible_moe_impls"] = [type(moe_impl)] + if "compatible_moe_impls" not in weight_processing: + weight_processing["compatible_moe_impls"] = [type(moe_impl)] + + # Use the complete skip_checks function from test_utils + skip_checks( + moe_impl=moe_impl, + routing_config=routing_config, + weight_processing=weight_processing, + gated_act_type=GatedActType.SwiGlu, + num_tokens=seq_len, + hidden_size=7168, # DeepSeek-V3 hidden size + intermediate_size=intermediate_size, + ) + device = "cuda" torch.manual_seed(42) @@ -453,7 +653,7 @@ def test_correctness_dpsk_fp8_fused_moe( ) # Run reference (returns bf16) - ref_out = run( + ref_out = run_fp8_block_scale_moe_reference( routing_logits=inputs["routing_logits"], routing_bias=inputs["routing_bias"], hidden_states=inputs["hidden_states"], @@ -473,6 +673,42 @@ def test_correctness_dpsk_fp8_fused_moe( topk_group=TOPK_GROUP, ) + # Prepare weights based on weight_processing configuration + use_shuffled_weight = 
weight_processing["use_shuffled_weight"] + weight_layout = weight_processing["layout"] + + gemm1_weights = inputs["gemm1_weights"] + gemm2_weights = inputs["gemm2_weights"] + + if use_shuffled_weight: + # Apply weight shuffling similar to the trtllm_gen_fused_moe test + epilogue_tile_m = ( + 64 # todo(yingyi): FIXME: this depends on the kernel internals + ) + + gemm1_weights_shuffled = [] + gemm2_weights_shuffled = [] + + for i in range(E_LOCAL): + # Shuffle weights for better performance + tmp_weights1 = shuffle_matrix_a( + gemm1_weights[i].view(torch.uint8), epilogue_tile_m + ) + tmp_weights2 = shuffle_matrix_a( + gemm2_weights[i].view(torch.uint8), epilogue_tile_m + ) + + if weight_layout == WeightLayout.BlockMajorK: + block_k = 128 + tmp_weights1 = convert_to_block_layout(tmp_weights1, block_k) + tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k) + + gemm1_weights_shuffled.append(tmp_weights1) + gemm2_weights_shuffled.append(tmp_weights2) + + gemm1_weights = torch.stack(gemm1_weights_shuffled).view(torch.float8_e4m3fn) + gemm2_weights = torch.stack(gemm2_weights_shuffled).view(torch.float8_e4m3fn) + # Run FlashInfer fused kernel with autotune(routing_config["enable_autotune"]): fi_out = trtllm_fp8_block_scale_moe( @@ -480,9 +716,9 @@ def test_correctness_dpsk_fp8_fused_moe( inputs["routing_bias"], # bf16 inputs["hidden_states"], # fp8 inputs["hidden_states_scale"], # [H/128, T] - inputs["gemm1_weights"], # fp8 + gemm1_weights, # fp8 (potentially shuffled) inputs["gemm1_weights_scale"].to(torch.float32), - inputs["gemm2_weights"], # fp8 + gemm2_weights, # fp8 (potentially shuffled) inputs["gemm2_weights_scale"].to(torch.float32), E_GLOBAL, TOP_K, @@ -493,61 +729,13 @@ def test_correctness_dpsk_fp8_fused_moe( inputs["local_num_experts"], inputs["routed_scaling_factor"], routing_method_type=2, # DeepSeek-styled - use_shuffled_weight=False, - weight_layout=WeightLayout.MajorK.value, + use_shuffled_weight=use_shuffled_weight, + weight_layout=weight_layout, enable_pdl=enable_pdl, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) - # Compare - ref_f32 = ref_out.float() - fi_f32 = fi_out.float() - - abs_diff = (ref_f32 - fi_f32).abs() - rel_diff = abs_diff / (fi_f32.abs() + 1e-8) - - print("\nComparison stats:") - print(f"Max abs diff: {abs_diff.max().item():.6e}") - print(f"Mean abs diff: {abs_diff.mean().item():.6e}") - print(f"Max rel diff: {rel_diff.max().item():.6e}") - print(f"Mean rel diff: {rel_diff.mean().item():.6e}") - - # Cosine similarity and MSE - cos_sim = torch.nn.functional.cosine_similarity( - ref_f32.flatten(), fi_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_f32 - fi_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - # Strict allclose - allclose = torch.allclose(ref_f32, fi_f32, atol=atol, rtol=rtol) - print(f"\nAllclose(atol={atol}, rtol={rtol}): {allclose}") - - if not allclose: - # Show top-5 largest absolute errors - flat = abs_diff.flatten() - k = min(5, flat.numel()) - topv, topi = torch.topk(flat, k) - print("\nTop-5 absolute error locations:") - for rank in range(k): - idx = topi[rank].item() - t = idx // H - h = idx % H - print( - f" [t={t}, h={h}]: ref={ref_f32.flatten()[idx].item():.6e}, " - f"fi={fi_f32.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}" - ) - - left = (ref_f32 - fi_f32).abs() - right = atol + rtol * fi_f32.abs() - ok = left <= right - hit_ratio = ok.float().mean().item() - print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {percent * 100:.2f}%)") - - assert hit_ratio >= percent, ( 
- f"Hit ratio {hit_ratio * 100:.2f}% is less than required {percent * 100:.2f}%" - ) + stats_accuracy(ref_out, fi_out, atol=atol, rtol=rtol, percent=percent) if __name__ == "__main__": @@ -567,4 +755,8 @@ def test_correctness_dpsk_fp8_fused_moe( "enable_autotune": True, }, enable_pdl=True, + weight_processing={ + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + }, ) diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 35f4ad61e7..28ca3e3947 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -16,7 +16,6 @@ import pytest from abc import ABC, abstractmethod -from enum import IntEnum from typing import Dict import torch from cuda.bindings import runtime @@ -46,7 +45,7 @@ get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) -from flashinfer.utils import get_compute_capability +from .test_utils import skip_checks, QuantMode # Max num tokens to tune for trtllm-gen fused moe @@ -216,17 +215,6 @@ def _run_moe_computation(self, runtime_args): return output # Extract tensor from tuple -class QuantMode(IntEnum): - """Supported quantization modes for MoE testing.""" - - FP4_NVFP4_NVFP4 = 1 - FP4_MXFP4_MXFP8 = 2 - FP4_MXFP4_Bf16 = 3 - FP8_BLOCK_SCALE = 4 - FP8_PER_TENSOR = 5 - BF16 = 6 - - # ==================================================================================== # Abstract Base Class for MoE Implementations # ==================================================================================== @@ -2026,67 +2014,6 @@ def cache_permute_indices(): return _cache_permute_indices -def skip_checks( - moe_impl, - routing_config, - weight_processing, - gated_act_type, - num_tokens, - hidden_size, - intermediate_size, -): - """Common skip logic for all tests.""" - compute_capability = get_compute_capability(torch.device(device="cuda")) - if compute_capability[0] not in [10]: - pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") - # Skip incompatible combinations - if gated_act_type == GatedActType.GeGlu and ( - type(moe_impl) is not FP4Moe - or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4 - or routing_config["routing_method_type"] != RoutingMethodType.TopK - or num_tokens > 128 - ): - pytest.skip( - f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" - ) - elif gated_act_type == GatedActType.SwiGlu and ( - hidden_size > 1024 or intermediate_size > 1024 - ): - pytest.skip( - f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" - ) - - # Skip large intermediate sizes for configurations with many experts - if routing_config["num_experts"] >= 512 and intermediate_size > 512: - pytest.skip( - f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts" - ) - - if type(moe_impl) not in routing_config["compatible_moe_impls"]: - pytest.skip( - f"Incompatible: {moe_impl.name} + {routing_config['routing_method_type'].name}" - ) - if type(moe_impl) not in weight_processing["compatible_moe_impls"]: - pytest.skip( - f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}" - ) - if intermediate_size not in routing_config["compatible_intermediate_size"]: - pytest.skip( - f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)" - ) - - # TODO(jimmzhou): enable 
MxFP4xBf16 on SM103 - if ( - type(moe_impl) is FP4Moe - and moe_impl.quant_mode == QuantMode.FP4_MXFP4_Bf16 - and compute_capability[0] == 10 - and compute_capability[1] == 3 - ): - pytest.xfail( - "Note(jimmzhou): Make MxFP4xBf16 nonfunctional on SM103 to avoid B200 regression" - ) - - def run_moe_test( num_tokens, hidden_size, diff --git a/tests/moe/test_utils.py b/tests/moe/test_utils.py new file mode 100644 index 0000000000..ebaf85c189 --- /dev/null +++ b/tests/moe/test_utils.py @@ -0,0 +1,97 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest +import torch +from enum import IntEnum +from flashinfer import GatedActType, RoutingMethodType +from flashinfer.utils import get_compute_capability + + +class QuantMode(IntEnum): + """Supported quantization modes for MoE testing.""" + + FP4_NVFP4_NVFP4 = 1 + FP4_MXFP4_MXFP8 = 2 + FP4_MXFP4_Bf16 = 3 + FP8_BLOCK_SCALE = 4 + FP8_PER_TENSOR = 5 + BF16 = 6 + + +def skip_checks( + moe_impl, + routing_config, + weight_processing, + gated_act_type, + num_tokens, + hidden_size, + intermediate_size, +): + """Common skip logic for all tests.""" + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] not in [10]: + pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") + + # Check if moe_impl is FP4Moe by class name to avoid circular imports + is_fp4_moe = type(moe_impl).__name__ == "FP4Moe" + + # Skip incompatible combinations + if gated_act_type == GatedActType.GeGlu and ( + not is_fp4_moe + or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4 + or routing_config["routing_method_type"] != RoutingMethodType.TopK + or num_tokens > 128 + ): + pytest.skip( + f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" + ) + elif gated_act_type == GatedActType.SwiGlu and ( + hidden_size > 1024 or intermediate_size > 1024 + ): + pytest.skip( + f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" + ) + + # Skip large intermediate sizes for configurations with many experts + if routing_config["num_experts"] >= 512 and intermediate_size > 512: + pytest.skip( + f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts" + ) + + if type(moe_impl) not in routing_config["compatible_moe_impls"]: + pytest.skip( + f"Incompatible: {moe_impl.name} + {routing_config['routing_method_type'].name}" + ) + if type(moe_impl) not in weight_processing["compatible_moe_impls"]: + pytest.skip( + f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}" + ) + if intermediate_size not in routing_config["compatible_intermediate_size"]: + pytest.skip( + f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)" + ) + + # TODO(jimmzhou): enable MxFP4xBf16 on SM103 + if ( + is_fp4_moe + and 
moe_impl.quant_mode == QuantMode.FP4_MXFP4_Bf16 + and compute_capability[0] == 10 + and compute_capability[1] == 3 + ): + pytest.xfail( + "Note(jimmzhou): Make MxFP4xBf16 nonfunctional on SM103 to avoid B200 regression" + ) From 1c4b52232434db5eb5aecd971f14e204c9c80196 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 18 Nov 2025 12:53:43 -0800 Subject: [PATCH 068/130] hotfix: rename moe/test_utils.py to moe/utils.py (#2106) --- tests/moe/test_dpsk_fused_moe_fp8.py | 4 ++-- tests/moe/test_trtllm_gen_fused_moe.py | 2 +- tests/moe/{test_utils.py => utils.py} | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename tests/moe/{test_utils.py => utils.py} (100%) diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py index a472ecc5a0..711e05f234 100644 --- a/tests/moe/test_dpsk_fused_moe_fp8.py +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -7,7 +7,7 @@ WeightLayout, trtllm_fp8_block_scale_moe, ) -from .test_utils import skip_checks, QuantMode +from .utils import skip_checks, QuantMode from flashinfer import GatedActType @@ -611,7 +611,7 @@ def __init__(self): if "compatible_moe_impls" not in weight_processing: weight_processing["compatible_moe_impls"] = [type(moe_impl)] - # Use the complete skip_checks function from test_utils + # Use the complete skip_checks function from utils skip_checks( moe_impl=moe_impl, routing_config=routing_config, diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 28ca3e3947..1a78243593 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -45,7 +45,7 @@ get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) -from .test_utils import skip_checks, QuantMode +from .utils import skip_checks, QuantMode # Max num tokens to tune for trtllm-gen fused moe diff --git a/tests/moe/test_utils.py b/tests/moe/utils.py similarity index 100% rename from tests/moe/test_utils.py rename to tests/moe/utils.py From 219592ba27206ba38c8de8c5dd76137595d5ec45 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe <50598321+nvmbreughe@users.noreply.github.com> Date: Tue, 18 Nov 2025 20:16:59 -0600 Subject: [PATCH 069/130] [DSR1] Added MLA test (#2100) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Added DSR1 MLA test, and split up the trtllm_batch_decode_mla function. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Improved test suite for batch decoding by making maximum sequence length configurable, adding parameterized runs across short and long lengths, and introducing a compatibility wrapper to preserve legacy behavior. This enhances coverage and validation across varied sequence-length scenarios. 
--------- Co-authored-by: Zihao Ye --- tests/attention/test_trtllm_gen_mla.py | 91 +++++++++++++++++++++----- 1 file changed, 74 insertions(+), 17 deletions(-) diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index 508fce831d..d56be03eb6 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -9,20 +9,7 @@ workspace_size = 128 * 1024 * 1024 -@pytest.mark.parametrize( - "batch_size", - [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024], -) -@pytest.mark.parametrize("scale", [1.0, 0.5]) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) -@pytest.mark.parametrize("page_size", [32, 64]) -@pytest.mark.parametrize( - "q_len_per_request", [1, 2] -) # todo(Yingyi): verify larger q_len_per_request -@pytest.mark.parametrize("dynamic_scale", [False]) -@pytest.mark.parametrize("enable_pdl", [True, False, None]) -@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"]) -def test_trtllm_batch_decode_mla( +def trtllm_batch_decode_mla( batch_size: int, scale: float, dtype: torch.dtype, @@ -31,6 +18,7 @@ def test_trtllm_batch_decode_mla( dynamic_scale: bool, enable_pdl: bool, backend: str, + MAX_SEQ_LEN: int, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if backend == "xqa": @@ -49,9 +37,6 @@ def test_trtllm_batch_decode_mla( torch.manual_seed(42) device = "cuda:0" - # Fixed max sequence length - MAX_SEQ_LEN = 1024 - # Deepseek attention config (decode-MLA) num_q_heads = 128 qk_nope_head_dim = 128 @@ -239,3 +224,75 @@ def test_trtllm_batch_decode_mla( f"Total {o_ref.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, " f"require at least {required_ratio:.1%}" ) + + +@pytest.mark.parametrize( + "batch_size", + [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024], +) +@pytest.mark.parametrize("scale", [1.0, 0.5]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [32, 64]) +@pytest.mark.parametrize( + "q_len_per_request", [1, 2] +) # todo(Yingyi): verify larger q_len_per_request +@pytest.mark.parametrize("dynamic_scale", [False]) +@pytest.mark.parametrize("enable_pdl", [True, False, None]) +@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"]) +def test_trtllm_batch_decode_mla( + batch_size: int, + scale: float, + dtype: torch.dtype, + page_size: int, + q_len_per_request: int, + dynamic_scale: bool, + enable_pdl: bool, + backend: str, +): + trtllm_batch_decode_mla( + batch_size, + scale, + dtype, + page_size, + q_len_per_request, + dynamic_scale, + enable_pdl, + backend, + 1024, + ) + + +@pytest.mark.parametrize( + "batch_size", + [2, 4, 8], +) +@pytest.mark.parametrize("scale", [1.0, 0.5]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [64]) +@pytest.mark.parametrize("q_len_per_request", [1, 2, 3]) +@pytest.mark.parametrize("dynamic_scale", [False]) +@pytest.mark.parametrize("enable_pdl", [True, False, None]) +@pytest.mark.parametrize("backend", ["trtllm-gen"]) +@pytest.mark.parametrize("MAX_SEQ_LEN", [1024, 8960]) +def test_dsr1_trtllm_mla( + batch_size: int, + scale: float, + dtype: torch.dtype, + page_size: int, + q_len_per_request: int, + dynamic_scale: bool, + enable_pdl: bool, + backend: str, + MAX_SEQ_LEN: int, +): + trtllm_batch_decode_mla( + batch_size, + scale, + dtype, + page_size, + q_len_per_request, + dynamic_scale, + enable_pdl, + backend, + MAX_SEQ_LEN, + ) From b9964cc30edd0ddac84e68ef2528eeefb59c96e2 Mon Sep 17 
00:00:00 2001 From: "Brian K. Ryu" Date: Tue, 18 Nov 2025 21:10:28 -0800 Subject: [PATCH 070/130] test: Enable testing for trtllm-gen decode bs1 (#2103) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description In #1898, it was raised that trtllm-gen's attention kernels fail for batch size 1. The prefill kernel was fixed in #1912 and prefill tests have been enabled. Further updates to trtllm-gen kernels have also fixed the decode batch size 1 issue. Current PR re-enables testing. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Tests** * Expanded batch_decode test scenarios to cover additional small-batch and page-size combinations. * Increased coverage for max_in_kv_len by testing multiple length options instead of a single value. * Restored previously marked-as-expected-failure case to run normally, improving overall test pass coverage. --------- Co-authored-by: Zihao Ye --- tests/attention/test_trtllm_gen_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 642c437e59..4e2c615827 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -1041,6 +1041,7 @@ def test_trtllm_batch_decode( "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", [ (1, 1, 16, 8, 8), + (1, 1, 32, 8, 8), ], ) @pytest.mark.parametrize("window_left", [-1]) @@ -1052,7 +1053,7 @@ def test_trtllm_batch_decode( ) @pytest.mark.parametrize("enable_pdl", [None]) @pytest.mark.parametrize("enable_sink", [False]) -@pytest.mark.parametrize("max_in_kv_len", [8192]) +@pytest.mark.parametrize("max_in_kv_len", [4096, 8192]) @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("device_scale", [True, False]) def test_trtllm_batch_decode_bs1( @@ -1073,7 +1074,6 @@ def test_trtllm_batch_decode_bs1( device_scale, ): # Small number of test cases for batch size 1 - pytest.xfail("trtllm-gen decode gets incorrect output with bs1") _test_trtllm_batch_decode( "trtllm-gen", kv_layout, From 3a234053930f869676488b551d41af8c6a705e2a Mon Sep 17 00:00:00 2001 From: nv-yunzheq Date: Wed, 19 Nov 2025 02:33:24 -0800 Subject: [PATCH 071/130] [DSV3] Optimized routing kernels dsv3 (#2099) --- csrc/fused_moe/moeTopKFuncs.cuh | 254 +++++++++ csrc/fused_moe/noAuxTcKernels.cu | 450 ++++++++++++++++ flashinfer/dsv3_ops/__init__.py | 2 + flashinfer/fused_moe/__init__.py | 5 + flashinfer/fused_moe/fused_routing_dsv3.py | 194 +++++++ flashinfer/jit/__init__.py | 3 + flashinfer/jit/dsv3_optimizations.py | 34 ++ .../trtllm/fused_moe/noAuxTcKernels.h | 33 ++ .../test_dsv3_fused_routing.py | 501 ++++++++++++++++++ 9 files changed, 1476 insertions(+) create mode 100644 csrc/fused_moe/moeTopKFuncs.cuh create mode 100644 
csrc/fused_moe/noAuxTcKernels.cu create mode 100644 flashinfer/fused_moe/fused_routing_dsv3.py create mode 100644 include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h create mode 100644 tests/model_optimizations/test_dsv3_fused_routing.py diff --git a/csrc/fused_moe/moeTopKFuncs.cuh b/csrc/fused_moe/moeTopKFuncs.cuh new file mode 100644 index 0000000000..e34c5f2665 --- /dev/null +++ b/csrc/fused_moe/moeTopKFuncs.cuh @@ -0,0 +1,254 @@ + +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#ifndef TRTLLM_MOETOPKFUNCS_CUH_H +#define TRTLLM_MOETOPKFUNCS_CUH_H + +#include +#include + +#include + +#include "flashinfer/arch_condition.h" + +namespace tensorrt_llm::kernels { + +namespace reduce_topk { +namespace cg = cooperative_groups; +static constexpr int kWARP_SIZE = 32; +static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = flashinfer::arch::is_major_v<10>; + +template +struct TopKRedType { + using T = T_; + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Top K reduction only implemented for int, float, float16 and bfloat16"); + + using TypeCmp = std::conditional_t; + using IdxT = std::conditional_t; + + static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16; + static constexpr int kMaxIdx = 65535; + TypeCmp compValIdx; + + static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) { + auto valueBits = + cub::Traits::TwiddleIn(reinterpret_cast::UnsignedBits&>(val)); + TypeCmp compactTmp = valueBits; + compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx)); + // Use 65535 minus idx to give higher priority to elements with smaller indices. 
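+    // For 4-byte types kMoveBits is 32, so the packed key is
+    // (twiddled value << 32) | (65535 - idx); when two candidates hold the same
+    // value the low 16 bits decide, and a smaller idx yields a larger key, so the
+    // max-reduction breaks ties toward the smaller expert index.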
+ return compactTmp; + } + + static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp) { + // Since โ€œ65535-idxโ€ is always smaller than 65536 and positive, we can directly use it as the + // lower 16 bits + index = kMaxIdx - static_cast((cmp & 0xFFFF)); + + auto compactTmp = cmp >> kMoveBits; + auto valueBits = cub::Traits::TwiddleOut( + reinterpret_cast::UnsignedBits&>(compactTmp)); + value = reinterpret_cast(valueBits); + } + + __host__ __device__ TopKRedType() = default; + + __host__ __device__ TopKRedType(T val, int32_t idx) : compValIdx(makeCmpVal(val, idx)) {} + + __host__ __device__ operator TypeCmp() const noexcept { return compValIdx; } + + __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) { + if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8) { + return cg::reduce(warp, compValIdx, cg::greater{}); + } else { + TypeCmp result; + asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx)); + return result; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TopKIdx { + // by default, empty +}; + +template +struct TopKIdx { + static constexpr int K = K_; + int32_t val[K]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define TOPK_SWAP(I, J) \ + { \ + auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \ + auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \ + topK[I].compValIdx = pairMax; \ + topK[J].compValIdx = pairMin; \ + } + +template +struct Sort; + +template +struct Sort<1, RedType> { + static __device__ void run(RedType* topK) {} +}; + +template +struct Sort<2, RedType> { + static __device__ void run(RedType* topK) { TOPK_SWAP(0, 1); } +}; + +template +struct Sort<3, RedType> { + static __device__ void run(RedType* topK) { + TOPK_SWAP(0, 1); + TOPK_SWAP(1, 2); + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<4, RedType> { + static __device__ void run(RedType* topK) { + TOPK_SWAP(0, 2); + TOPK_SWAP(1, 3); + TOPK_SWAP(0, 1); + TOPK_SWAP(2, 3); + TOPK_SWAP(1, 2); + } +}; + +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], Type value, + int32_t idx, Type const minValue, int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + using RedType = TopKRedType; + RedType topK{value, idx}; + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct + { + topK = kk > 0 && packedMax == topK.compValIdx ? 
RedType{minValue, idx} : topK; + // get the next largest value + packedMax = topK.reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +template +__device__ void reduceTopKFunc(cg::thread_block_tile const& warp, Type (&out)[K], + int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], + Type minValue, int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N < 5, "Only support candidates number less than or equal to 128"); + using RedType = TopKRedType; + RedType topK[N]; +#pragma unroll + for (int nn = 0; nn < N; ++nn) { + topK[nn] = RedType{value[nn], idx[nn]}; + } + + if constexpr (!IsSorted) { + Sort::run(topK); + } + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < actualK; ++kk) { + bool update = kk > 0 && packedMax == topK[0].compValIdx; +#pragma unroll + for (int nn = 0; nn < N; ++nn) { + topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} + : update ? topK[nn + 1] + : topK[nn]; + } + // get the next largest value + packedMax = topK[0].reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], Type (&value)[N], + int32_t (&idx)[N], Type const minValue, + int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512"); + static_assert( + N <= 4 || N % 4 == 0, + "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4"); + using RedType = TopKRedType; + + if constexpr (N <= 4) { + reduceTopKFunc(warp, out, outIdx, value, idx, minValue, actualK); + } else { + constexpr int numLoops = N / 4; + constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1; + + Type topKBufferValue[numResults]; + int32_t topKBufferIdx[numResults]; + int32_t laneIdx = threadIdx.x % kWARP_SIZE; + + for (int ii = 0; ii < numResults; ++ii) { + topKBufferValue[ii] = minValue; + topKBufferIdx[ii] = ii * kWARP_SIZE - 1; //@todo: check if this is correct + } + for (int loop = 0; loop < numLoops; ++loop) { + int start = loop * 4; + Type topKValue[K]; + int32_t topKIdx[K]; + Type inValue[4]; + int32_t inIdx[4]; + for (int i = 0; i < 4; ++i) { + inValue[i] = value[start + i]; + inIdx[i] = idx[start + i]; + } + reduceTopKFunc(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK); + int inOffset = laneIdx % K; + if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) { + topKBufferValue[0] = topKValue[inOffset]; + topKBufferIdx[0] = topKIdx[inOffset]; + } + if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE))) { + topKBufferValue[1] = topKValue[inOffset]; + topKBufferIdx[1] = topKIdx[inOffset]; + } + } + + reduceTopKFunc(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, + actualK); + } +}; + +#undef TOPK_SWAP + +} // namespace reduce_topk +} // namespace tensorrt_llm::kernels +#endif // TRTLLM_MOETOPKFUNCS_CUH_H diff --git a/csrc/fused_moe/noAuxTcKernels.cu b/csrc/fused_moe/noAuxTcKernels.cu new file mode 100644 index 0000000000..1f57d9b57b --- /dev/null +++ b/csrc/fused_moe/noAuxTcKernels.cu @@ -0,0 +1,450 @@ +#include +#include + +#include + +#include "flashinfer/trtllm/fused_moe/noAuxTcKernels.h" 
+#include "moeTopKFuncs.cuh" +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include "tensorrt_llm/common/envUtils.h" +#include "tvm_ffi_utils.h" + +namespace cg = cooperative_groups; +using namespace tensorrt_llm::common; + +namespace tensorrt_llm::kernels { +static constexpr int WARP_SIZE = 32; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxNumExpertsUnit = 128; +static constexpr int NumTopGroupScores = 2; +static constexpr int MaxNumTopExperts = 8; +static constexpr int MaxNumTopGroups = 4; + +static __device__ inline float sigmoid_accurate(float x) { return 0.5f * tanhf(0.5f * x) + 0.5f; } + +template +__global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, IdxT* topkIndices, + BiasT* routingBias, int64_t const numTokens, + int64_t const numGroup, int64_t const topkGroup, + int64_t const topk, int64_t const numExperts, + int64_t const numExpertsPerGroup, + double const routedScalingFactor) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // declare shared memory structure + // number of experts is bounded by number of threads + __shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts]; + __shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts]; + // number of expert groups is bounded by number of warps + int constexpr NumWarps = MaxNumExperts / WARP_SIZE; + __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps]; + + // needed for warp reduce + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // for the final reduction of weight norm, only some lanes need to participate + int32_t laneIdx = threadIdx.x % WARP_SIZE; + int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0); + + if constexpr (UseGroups) { + if (warpIdx >= numGroup) { + return; + } + } + + // note that for invalid scores, we simply use a negative value: + // they work well even with the compacted format used in topK, and + // sigmoid / bias activated scores cannot be negative + static constexpr float invalidScoreFloat = float{-INFINITY}; + const OutputT invalidScore = OutputT{invalidScoreFloat}; + + // load bias already; each warp represents one expert group + auto threadExpert = threadIdx.x; + bool expertSelected = threadExpert < numExperts; + if constexpr (UseGroups) { + threadExpert = warpIdx * numExpertsPerGroup + laneIdx; + expertSelected = laneIdx < numExpertsPerGroup; + } + + auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert; + auto biasVal = expertSelected ? static_cast(routingBias[threadExpert]) : invalidScoreFloat; + topkValues += blockIdx.x * topk; + topkIndices += blockIdx.x * topk; + + // get our assigned thread score; each warp represents one expert group + float score = expertSelected ? 
static_cast(scores[scoreIdx]) : invalidScoreFloat; + auto scoreSigmoid = sigmoid_accurate(score); + // write the sigmoid score to shared for later use + if (expertSelected) { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + + // get the score with bias + // note that with invalid values, because sigmoid is < 1 and bias is -1, + // we must get a negative value, which is smaller than any valid value + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + + if (expertSelected) { + smemScoreBias[threadExpert] = scoreBias; + } + + // registers for top group score reduction + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; + float topGroups[MaxNumTopGroups]; // bound of numGroup + int32_t topGroupIdx[MaxNumTopGroups]; + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; + float topScores[MaxNumTopExperts]; // bound of topk + int32_t topExperts[MaxNumTopExperts]; + + if constexpr (UseGroups) { + reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, + /* minValue */ invalidScoreFloat); + + // get the final group score and write it to shared + if (laneIdx == 0) { + auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; + smemGroupScores[warpIdx] = groupScore; + } + } + + // make group scores available to all warps + __syncthreads(); + + if constexpr (UseGroups) { + if (warpIdx == 0) { + // a single warp performs the selection of top groups, and goes on to select the final experts + float groupScore = laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat; + + reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + /* minValue */ invalidScoreFloat); + + // final expert selection: get relevant indexes and scores from shared + +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of numGroup + // auto groupIdx = topGroupIdx[ii]; + auto groupIdx = (ii < topkGroup) ? topGroupIdx[ii] : 0; + expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx; + + expertScoreGroup[ii] = (ii < topkGroup) && expertSelected + ? smemScoreBias[expertIdxGroup[ii]] + : invalidScoreFloat; + } + + tensorrt_llm::kernels::reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, + expertIdxGroup, + /* minValue */ invalidScoreFloat, topk); + } + } else if constexpr (MaxNumExperts > MaxNumExpertsUnit) { + // without groups, and the expert number is larger than MaxNumExpertsUnit, + // we need to use multiple warps to calculate the intermediate topk results + + int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; + int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; + __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; + __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; + if (warpIdx < NumExpertWarps) { + int offset = warpIdx * WARP_SIZE * MaxNumTopGroups; +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = + offset + expertIdx < numExperts ? 
smemScoreBias[offset + expertIdx] : invalidScoreFloat; + } + reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, topk); + + if (laneIdx < topk) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } + } + __syncthreads(); + if (warpIdx == 0) { + int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1; + float intermidiateScore[NumInterTopKPerThread]; + int32_t intermidiateExpert[NumInterTopKPerThread]; + for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) { + int ii = i / WARP_SIZE; + if (i < NumInterTopK) { + intermidiateScore[ii] = smemInterTopScores[i]; + intermidiateExpert[ii] = smemInterTopExperts[i]; + } else { + intermidiateScore[ii] = invalidScoreFloat; + intermidiateExpert[ii] = MaxNumExperts - 1; + } + } + reduce_topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert, + /* minValue */ invalidScoreFloat, topk); + } + } else { + // without groups, and the expert number is smaller than MaxNumExpertsUnit + // each thread just takes `MaxNumTopGroups` experts + if (warpIdx == 0) { +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = + expertIdx < numExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; + } + reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, topk); + } + } + + if (warpIdx == 0) { + // determine our lane's expert index and write to output + int32_t expertIdx = laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + // norm the value + float scoreNorm = laneIdx < topk ? 
smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = static_cast(scoreNorm * routedScalingFactor / (redNorm + 1e-20)); + // store the topk scores and experts to output + if (laneIdx < topk) { + topkValues[laneIdx] = static_cast(finalScore); + topkIndices[laneIdx] = expertIdx; + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, + int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, + bool const launch_with_pdl, cudaStream_t const stream) { + // Check if we can use the optimized deepseek_v3_topk_kernel + bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts); + + int64_t const experts_per_group = num_experts / n_group; + bool const is_multi_group = (n_group != 1) && (num_experts <= NumDeepseekExperts) && + (experts_per_group <= WARP_SIZE) && + (experts_per_group * topk_group <= MaxNumExpertsUnit); + + if (is_single_group || is_multi_group) { + cudaLaunchConfig_t config; + auto* kernel_instance = + &deepseek_v3_topk_kernel; + int num_threads = NumDeepseekExperts; + if (is_single_group) { + if (num_experts > MaxNumExpertsUnit) { + kernel_instance = + &deepseek_v3_topk_kernel; + num_threads = NumKimiK2Experts; + } else { + kernel_instance = + &deepseek_v3_topk_kernel; + num_threads = MaxNumExpertsUnit; + } + } + + config.gridDim = num_tokens; + config.blockDim = num_threads; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = launch_with_pdl; + config.numAttrs = 1; + config.attrs = attrs; + + cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group, + routed_scaling_factor); + sync_check_cuda_error(stream); + } else { + // TODO: call the generic path (previous implementation) or signal unsupported config. + TLLM_CHECK_WITH_INFO(false, + "invokeNoAuxTc: unsupported configuration (n_group=%ld, num_experts=%ld, " + "topk_group=%ld). 
Please use " + "original pytorch implementation.", + n_group, num_experts, topk_group); + } +} + +#define INSTANTIATE_NOAUX_TC(InputT, BiasT, OutputT, IdxT) \ + template void invokeNoAuxTc( \ + InputT * scores, BiasT * bias, OutputT * topk_values, IdxT * topk_indices, \ + int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, \ + int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, \ + bool const launch_with_pdl, cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC(float, float, float, int32_t); +INSTANTIATE_NOAUX_TC(float, half, float, int32_t); + +INSTANTIATE_NOAUX_TC(half, float, half, int32_t); +INSTANTIATE_NOAUX_TC(half, half, half, int32_t); + +#ifdef ENABLE_BF16 +INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, float, int32_t); +INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, half, int32_t); + +INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, __nv_bfloat16, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, __nv_bfloat16, int32_t); +#endif + +} // namespace tensorrt_llm::kernels + +namespace flashinfer::trtllm_dsv3_fused_routing { +// th::Tensor const& scores, th::Tensor const& bias, int64_t n_group, +// int64_t topk_group, int64_t topk, double routed_scaling_factor +// th::Tensor topk_values, th::Tensor topk_indices + +void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_group, int64_t topk, + double routed_scaling_factor, TensorView topk_values, TensorView topk_indices, + bool launch_with_pdl) { + auto data_type = scores.dtype(); + auto bias_type = bias.dtype(); + + auto input_size = scores.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + + TVM_FFI_ICHECK(input_size.size() == 2) << "scores must be a 2D Tensor"; + TVM_FFI_ICHECK((scores.device().device_type == kDLCUDA) && (bias.device().device_type == kDLCUDA)) + << "scores and bias must be CUDA tensors"; + TVM_FFI_ICHECK(scores.device().device_id == bias.device().device_id) + << "scores and bias must be on the same device"; + TVM_FFI_ICHECK(bias.dim() == 1 && bias.numel() == num_experts) + << "bias must be 1D with length == number of experts (%ld)"; + TVM_FFI_ICHECK(num_experts % n_group == 0) << "num_experts should be divisible by n_group"; + TVM_FFI_ICHECK(n_group <= 32) + << "n_group should be smaller than or equal to 32 for now"; //@todo: remove this restriction + // later + TVM_FFI_ICHECK(topk <= 32) + << "topk should be smaller than or equal to 32 for now"; //@todo: remove this restriction + // later + TVM_FFI_ICHECK(topk_values.dim() == 2) << "topk_values must be a 2D Tensor"; + TVM_FFI_ICHECK(topk_indices.dim() == 2) << "topk_indices must be a 2D Tensor"; + TVM_FFI_ICHECK(topk_values.sizes()[0] == num_tokens) + << "topk_values must have the same number of tokens as scores"; + TVM_FFI_ICHECK(topk_indices.sizes()[0] == num_tokens) + << "topk_indices must have the same number of tokens as scores"; + TVM_FFI_ICHECK(topk_values.sizes()[1] == topk) + << "topk_values must have the same number of topk as scores"; + TVM_FFI_ICHECK(topk_indices.sizes()[1] == topk) + << "topk_indices must have the same number of topk as scores"; + TVM_FFI_ICHECK(topk_values.dtype() == data_type) + << "topk_values must have the same dtype as scores"; + TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) + << "topk_indices must have the same dtype as scores"; + + auto stream = get_stream(scores.device()); + using namespace tensorrt_llm::kernels; + switch 
(encode_dlpack_dtype(data_type)) {
+    case float16_code:
+      // Handle Float16
+      switch (encode_dlpack_dtype(bias_type)) {
+        case float16_code:
+          invokeNoAuxTc<half, half, half, int32_t>(
+              reinterpret_cast<half*>(scores.data_ptr()),
+              reinterpret_cast<half*>(bias.data_ptr()),
+              reinterpret_cast<half*>(topk_values.data_ptr()),
+              reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts,
+              n_group, topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
+          break;
+        case float32_code:
+          invokeNoAuxTc<half, float, half, int32_t>(
+              reinterpret_cast<half*>(scores.data_ptr()),
+              reinterpret_cast<float*>(bias.data_ptr()),
+              reinterpret_cast<half*>(topk_values.data_ptr()),
+              reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts,
+              n_group, topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
+          break;
+        case bfloat16_code:
+          invokeNoAuxTc<half, __nv_bfloat16, half, int32_t>(
+              reinterpret_cast<half*>(scores.data_ptr()),
+              reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()),
+              reinterpret_cast<half*>(topk_values.data_ptr()),
+              reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts,
+              n_group, topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
+          break;
+        default:
+          throw std::invalid_argument(
+              "Invalid bias dtype, only supports float16, float32, and bfloat16");
+          break;
+      }
+      break;
+    case float32_code:
+      switch (encode_dlpack_dtype(bias_type)) {
+        case float32_code:
+          invokeNoAuxTc<float, float, float, int32_t>(
+              reinterpret_cast<float*>(scores.data_ptr()),
+              reinterpret_cast<float*>(bias.data_ptr()),
+              reinterpret_cast<float*>(topk_values.data_ptr()),
+              reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts,
+              n_group, topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
+          break;
+        case float16_code:
+          invokeNoAuxTc<float, half, float, int32_t>(
+              reinterpret_cast<float*>(scores.data_ptr()),
+              reinterpret_cast<half*>(bias.data_ptr()),
+              reinterpret_cast<float*>(topk_values.data_ptr()),
+              reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts,
+              n_group, topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
+          break;
+        case bfloat16_code:
+          invokeNoAuxTc<float, __nv_bfloat16, float, int32_t>(
+              reinterpret_cast<float*>(scores.data_ptr()),
+              reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()),
+              reinterpret_cast<float*>(topk_values.data_ptr()),
+              reinterpret_cast<int32_t*>(topk_indices.data_ptr()),
num_tokens, num_experts, n_group, + topk_group, topk, routed_scaling_factor, launch_with_pdl, stream); + break; + default: + throw std::invalid_argument( + "Invalid bias dtype, only supports bfloat16, float16, and float32"); + break; + } + break; + default: + // Handle other data types + throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16"); + break; + } +} +TVM_FFI_DLL_EXPORT_TYPED_FUNC(NoAuxTc, flashinfer::trtllm_dsv3_fused_routing::NoAuxTc); +} // namespace flashinfer::trtllm_dsv3_fused_routing diff --git a/flashinfer/dsv3_ops/__init__.py b/flashinfer/dsv3_ops/__init__.py index 49fb43b3ec..05a7c4e657 100644 --- a/flashinfer/dsv3_ops/__init__.py +++ b/flashinfer/dsv3_ops/__init__.py @@ -1,5 +1,7 @@ from flashinfer.gemm import mm_M1_16_K7168_N256 +from flashinfer.fused_moe import NoAuxTc __all__ = [ "mm_M1_16_K7168_N256", + "NoAuxTc", ] diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 84e3ade9c7..87c207f5e0 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -33,6 +33,10 @@ trtllm_bf16_moe, ) +from .fused_routing_dsv3 import ( # noqa: F401 + NoAuxTc as NoAuxTc, +) + __all__ = [ "RoutingMethodType", "GatedActType", @@ -50,4 +54,5 @@ "trtllm_fp4_block_scale_routed_moe", "trtllm_fp8_block_scale_moe", "trtllm_fp8_per_tensor_scale_moe", + "NoAuxTc", ] diff --git a/flashinfer/fused_moe/fused_routing_dsv3.py b/flashinfer/fused_moe/fused_routing_dsv3.py new file mode 100644 index 0000000000..bb12472272 --- /dev/null +++ b/flashinfer/fused_moe/fused_routing_dsv3.py @@ -0,0 +1,194 @@ +from flashinfer.jit import gen_dsv3_fused_routing_module +import functools +from types import SimpleNamespace +import torch +from flashinfer.utils import ( + register_custom_op, + supported_compute_capability, + backend_requirement, +) + + +@supported_compute_capability([89, 90, 100, 103, 120, 121]) +def _check_dsv3_fused_routing_supported( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl, +): + """Validate configuration parameters for DSv3 fused routing kernel. 
+ + Args: + scores: Input routing scores tensor + bias: Per-expert routing bias tensor + n_group: Number of expert groups + topk_group: Number of top groups to select + topk: Number of top experts to select per token + routed_scaling_factor: Scaling factor for normalized weights + topk_values: Output tensor for normalized expert weights + topk_indices: Output tensor for selected expert indices + launch_with_pdl: Whether to use Persistent Device-side Launch + + Raises: + ValueError: If configuration is invalid or exceeds kernel limits + """ + # Extract number of experts from scores shape + num_experts = scores.shape[1] + + # Check basic configuration constraints + if topk_group * n_group < topk or topk_group > n_group: + raise ValueError( + f"Invalid configuration: topk_group * n_group ({topk_group * n_group}) must be >= topk ({topk}) " + f"and topk_group ({topk_group}) must be <= n_group ({n_group})" + ) + + # Check kernel limits based on number of groups + if n_group > 1: + experts_per_group = num_experts / n_group + max_experts_in_selected_groups = experts_per_group * topk_group + + if topk > 8: + raise ValueError( + f"Invalid configuration for n_group > 1: topk ({topk}) must be <= 8" + ) + if experts_per_group > 32: + raise ValueError( + f"Invalid configuration for n_group > 1: num_experts / n_group " + f"({experts_per_group}) must be <= 32" + ) + if max_experts_in_selected_groups > 128: + raise ValueError( + f"Invalid configuration for n_group > 1: num_experts / n_group * topk_group " + f"({max_experts_in_selected_groups}) must be <= 128" + ) + else: # n_group == 1 + if num_experts > 384: + raise ValueError( + f"Invalid configuration for n_group = 1: num_experts ({num_experts}) must be <= 384" + ) + if topk > 8: + raise ValueError( + f"Invalid configuration for n_group = 1: topk ({topk}) must be <= 8" + ) + + return True + + +@functools.cache +def get_dsv3_fused_routing_module(): + module = gen_dsv3_fused_routing_module().build_and_load() + + @register_custom_op( + "flashinfer::NoAuxTc", + mutates_args=["topk_values", "topk_indices"], + ) + def NoAuxTc( + scores: torch.Tensor, + bias: torch.Tensor, + n_group: int, + topk_group: int, + topk: int, + routed_scaling_factor: float, + topk_values: torch.Tensor, + topk_indices: torch.Tensor, + launch_with_pdl: bool = True, + ) -> None: + module.NoAuxTc( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl, + ) + + return SimpleNamespace( + NoAuxTc=NoAuxTc, + ) + + +@backend_requirement({}, common_check=_check_dsv3_fused_routing_supported) +def NoAuxTc( + scores: torch.Tensor, + bias: torch.Tensor, + n_group: int, + topk_group: int, + topk: int, + routed_scaling_factor: float, + topk_values: torch.Tensor, + topk_indices: torch.Tensor, + launch_with_pdl: bool = True, +) -> None: + """Fused expert routing with top-k selection for DeepSeek-V3. + + This function performs a highly optimized fused routing operation specifically + designed for DeepSeek-V3's Mixture of Experts (MoE) architecture with grouped + expert routing and no auxiliary loss. It combines score computation, expert + selection, and normalization into a single kernel operation. + + The routing algorithm consists of the following steps: + 1. Compute biased scores: sigmoid(scores) + bias for each expert + 2. Group experts and compute group scores (sum of top-2 experts per group) + 3. Select top-k groups based on group scores + 4. From selected groups, select top-k experts based on biased scores + 5. 
Normalize selected expert weights: sigmoid_scores / sum(sigmoid_scores) * scale
+
+    Args:
+        scores (torch.Tensor): Input routing scores of shape (num_tokens, num_experts).
+            The logits produced by the router network before activation. Supports
+            bfloat16, float16, or float32.
+        bias (torch.Tensor): Per-expert routing bias of shape (num_experts,). Added to
+            sigmoid-activated scores to produce biased scores for expert selection.
+            Must match the dtype of scores.
+        n_group (int): Number of expert groups. Experts are divided into groups for
+            hierarchical selection. Typical value is 8 for DeepSeek-V3 with 256 experts
+            (32 experts per group).
+        topk_group (int): Number of top groups to select. Must be <= n_group. Typical
+            value is 4, meaning the top 4 groups are selected from 8 groups.
+        topk (int): Number of top experts to select per token. Must be <= num_experts.
+            Typical value is 8, meaning 8 experts are routed per token.
+        routed_scaling_factor (float): Scaling factor applied to normalized expert
+            weights. The final output weights are:
+            sigmoid_scores / sum(sigmoid_scores) * routed_scaling_factor.
+        topk_values (torch.Tensor): Pre-allocated output tensor of shape
+            (num_tokens, topk) for the normalized expert weights. Must have the same
+            dtype as scores. This tensor is mutated in-place.
+        topk_indices (torch.Tensor): Pre-allocated output tensor of shape
+            (num_tokens, topk) for the selected expert indices. Must be int32.
+            This tensor is mutated in-place.
+        launch_with_pdl (bool, optional): Whether to launch the kernel using Persistent
+            Device-side Launch. Defaults to True.
+
+    Returns:
+        None: Results are written directly to `topk_values` and `topk_indices` tensors.
+
+    Note:
+        - The kernel uses float32 internally for all computations to ensure numerical
+          precision, even when inputs are float16 or bfloat16.
+        - This implementation is optimized for Ada (compute capability 89), Hopper
+          (compute capability 90), and Blackwell (compute capability 100, 103, 120,
+          121) architectures.
+        - The "NoAux" prefix indicates this variant does not compute auxiliary losses
+          (e.g., load balancing loss) during routing.
+        - The "Tc" suffix indicates the use of Tensor Core optimizations in the
+          underlying CUDA kernel.
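+
+    Example:
+        A minimal illustrative sketch; the token count, expert counts, and scaling
+        factor below are example values only:
+
+        >>> import torch
+        >>> from flashinfer.fused_moe import NoAuxTc
+        >>> num_tokens, num_experts, topk = 4, 256, 8
+        >>> scores = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
+        >>> bias = torch.randn(num_experts, device="cuda", dtype=torch.float32)
+        >>> topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
+        >>> topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
+        >>> NoAuxTc(scores, bias, n_group=8, topk_group=4, topk=topk,
+        ...         routed_scaling_factor=2.5, topk_values=topk_values,
+        ...         topk_indices=topk_indices)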
+ """ + get_dsv3_fused_routing_module().NoAuxTc( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl, + ) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 1aa6f44dbd..0bacf2d28b 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -80,6 +80,9 @@ from .dsv3_optimizations import ( gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module, ) +from .dsv3_optimizations import ( + gen_dsv3_fused_routing_module as gen_dsv3_fused_routing_module, +) cuda_lib_path = os.environ.get( diff --git a/flashinfer/jit/dsv3_optimizations.py b/flashinfer/jit/dsv3_optimizations.py index 88be890699..9aa720fa59 100644 --- a/flashinfer/jit/dsv3_optimizations.py +++ b/flashinfer/jit/dsv3_optimizations.py @@ -9,3 +9,37 @@ def gen_dsv3_router_gemm_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / "dsv3_router_gemm.cu", ], ) + + +def gen_dsv3_fused_routing_module() -> JitSpec: + return gen_jit_spec( + "dsv3_fused_routing", + [ + jit_env.FLASHINFER_CSRC_DIR / "fused_moe/noAuxTcKernels.cu", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu", + ], + extra_include_paths=[ + jit_env.FLASHINFER_CSRC_DIR / "nv_internal", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "cutlass_extensions" + / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "cutlass_kernels" + / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "cutlass_kernels", + ], + ) diff --git a/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h b/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h new file mode 100644 index 0000000000..5af8fe39db --- /dev/null +++ b/include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#include "tensorrt_llm/common/cudaUtils.h" + +namespace tensorrt_llm::kernels { + +template +void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, + int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, + cudaStream_t const stream = 0); + +} // namespace tensorrt_llm::kernels diff --git a/tests/model_optimizations/test_dsv3_fused_routing.py b/tests/model_optimizations/test_dsv3_fused_routing.py new file mode 100644 index 0000000000..1749e94f46 --- /dev/null +++ b/tests/model_optimizations/test_dsv3_fused_routing.py @@ -0,0 +1,501 @@ +""" +Test for NoAuxTc (DSv3 Fused Routing) Kernel + +This test validates the NoAuxTc kernel against a reference implementation, +accounting for numerical precision and tie-breaking differences. + +================================================================================ +DSv3 ROUTING ALGORITHM +================================================================================ + +1. Compute: sigmoid(scores) + bias for each expert (biased scores) +2. Group experts and compute group scores (sum of top-2 experts per group) +3. Select top-k groups based on group scores +4. From selected groups, select top-k experts based on biased scores +5. Normalize selected experts: sigmoid_scores / sum(sigmoid_scores) * scale + +================================================================================ +VALIDATION LOGIC FLOW +================================================================================ + +The test performs TWO stages of validation for each token: + +STAGE 1: EXPERT SELECTION VALIDATION +------------------------------------- +Checks if the kernel selected the correct (or acceptably tied) experts. + +1. Are kernel_experts == ref_experts (same set)? + YES โ†’ โœ… VALID (status: "exact") + Continue to Stage 2 to validate output values + NO โ†’ Continue to step 2 + +2. Are kernel_groups == ref_groups (same groups selected)? + YES โ†’ Continue to step 3 (same groups, different experts) + NO โ†’ Continue to step 4 (different groups) + +3. SAME GROUPS, DIFFERENT EXPERTS + Check if the differing experts have tied scores: + - Compute score_diff = max(diff_expert_scores) - min(diff_expert_scores) + - If score_diff < expert_tie_threshold: + โ†’ โœ… VALID (status: "tied_experts") + - Else: + โ†’ โŒ INVALID (status: "score_mismatch") + +4. DIFFERENT GROUPS + a) Are the groups tied? + - Compute all group scores (sum of top-2 experts per group) + - Check if differing groups have similar scores + - If group_score_diff < group_tie_threshold: + โ†’ Groups are tied, continue to step 4b + - Else: + โ†’ โŒ INVALID (status: "different_groups") + + b) Are the experts correct within kernel's groups? 
+ - Compute expected_experts = top-k experts from kernel's selected groups + - If kernel_experts == expected_experts: + โ†’ โœ… VALID (status: "tied_groups") + - Else, check if differing experts have tied scores: + - Compute score_diff for differing experts + - If score_diff < expert_tie_threshold: + โ†’ โœ… VALID (status: "tied_groups") + - Else: + โ†’ โŒ INVALID (status: "tied_groups_but_wrong_experts") + +STAGE 2: OUTPUT VALUE VALIDATION +--------------------------------- +For tokens where the SAME experts were selected (status: "exact"): +- Compare kernel output values vs reference output values +- Both are normalized scores: sigmoid_scores / sum(sigmoid_scores) * scale +- Check: abs(kernel_values - ref_values) within tolerance + - If within tolerance โ†’ โœ… VALID + - Else โ†’ โŒ INVALID (value mismatch) + +For tokens where DIFFERENT experts were selected (even if acceptably): +- SKIP value validation +- Reason: Different experts โ†’ different normalization sum โ†’ different values +- The expert selection validation already confirmed correctness + +Tolerance (data-type dependent): +- bfloat16: rtol=0.1, atol=0.1 +- float16: rtol=0.05, atol=0.05 +- float32: rtol=0.01, atol=0.01 + +================================================================================ +KEY CONCEPTS +================================================================================ + +1. **Group Ties**: When two groups have similar group scores (within threshold), + selecting either group is valid. The kernel may pick a different group than + the reference due to tie-breaking. + +2. **Expert Ties**: When experts have similar biased scores (within threshold), + selecting any of them is valid. The kernel may pick different experts due + to tie-breaking. + +3. **Tied Groups โ†’ Verify Experts**: When different groups are selected due to + ties, we must still verify that the kernel selected the correct top-k experts + WITHIN its chosen groups (not compare across different groups). + +4. **Float32 Internal Computation**: The kernel computes internally in float32 + even when inputs are float16/bfloat16. The reference must match this to + ensure consistent group/expert selection. + +================================================================================ +THRESHOLDS (Data-Type Dependent) +================================================================================ + + Expert Tie Group Tie + Threshold Threshold + bfloat16: 1.0 0.05 + float16: 0.5 0.02 + float32: 0.2 0.01 + +Group thresholds are higher because group scores are sums of 2 values, +accumulating more numerical error. + +================================================================================ +""" + +import torch +import pytest +from flashinfer.dsv3_ops import NoAuxTc +# from flashinfer.utils import get_compute_capability + + +class DSv3RoutingGroundTruth: + """ + Computes and stores all ground truth data for DSv3 routing. + Performs all computations in float32 to match kernel behavior. 
+ """ + + def __init__( + self, scores, bias, n_group, topk_group, topk, routed_scaling_factor, data_type + ): + self.num_tokens = scores.shape[0] + self.num_experts = scores.shape[1] + self.n_group = n_group + self.topk_group = topk_group + self.topk = topk + self.routed_scaling_factor = routed_scaling_factor + self.experts_per_group = self.num_experts // n_group + self.device = scores.device + + # Set thresholds based on data type + if data_type == torch.bfloat16: + self.expert_tie_threshold = 1.0 + self.group_tie_threshold = 0.05 + elif data_type == torch.float16: + self.expert_tie_threshold = 0.5 + self.group_tie_threshold = 0.02 + else: # float32 + self.expert_tie_threshold = 0.2 + self.group_tie_threshold = 0.01 + + # Convert to float32 to match kernel's internal computation + scores_f32 = scores.to(torch.float32) + bias_f32 = bias.to(torch.float32) + + # Compute sigmoid and biased scores + self.sigmoid_scores = torch.sigmoid(scores_f32) + self.biased_scores = self.sigmoid_scores + bias_f32 + + # Reshape for group-wise operations + scores_reshaped = self.biased_scores.view( + self.num_tokens, n_group, self.experts_per_group + ) + + # Compute group scores (sum of top-2 experts per group) + top2_per_group = torch.topk( + scores_reshaped, k=2, dim=-1, largest=True, sorted=True + )[0] + self.group_scores = torch.sum(top2_per_group, dim=-1) + + # Reference group selection + _, self.ref_group_indices = torch.topk( + self.group_scores, k=topk_group, dim=-1, largest=True, sorted=True + ) + + # Identify tied groups for each token + self.tied_group_sets = [] + for token_idx in range(self.num_tokens): + tied_groups = set() + group_scores_token = self.group_scores[token_idx] + + for g1 in range(n_group): + for g2 in range(g1 + 1, n_group): + score_diff = abs(group_scores_token[g1] - group_scores_token[g2]) + if score_diff < self.group_tie_threshold: + tied_groups.add(g1) + tied_groups.add(g2) + + self.tied_group_sets.append(tied_groups) + + # Compute reference expert selection and normalization + self.ref_expert_indices = torch.zeros( + self.num_tokens, topk, dtype=torch.long, device=self.device + ) + self.ref_expert_values = torch.zeros( + self.num_tokens, topk, dtype=torch.float32, device=self.device + ) + + for token_idx in range(self.num_tokens): + # Create mask for selected groups + group_mask = torch.zeros(n_group, dtype=torch.float32, device=self.device) + group_mask[self.ref_group_indices[token_idx]] = 1.0 + expert_mask = group_mask.repeat_interleave(self.experts_per_group) + + # Mask and select top-k experts + masked_biased_scores = self.biased_scores[token_idx] * expert_mask + _, topk_idx = torch.topk( + masked_biased_scores, k=topk, dim=-1, largest=True, sorted=True + ) + + # Normalize selected experts + selected_sigmoid_scores = self.sigmoid_scores[token_idx][topk_idx] + score_sum = selected_sigmoid_scores.sum() + 1e-20 + normalized_scores = ( + selected_sigmoid_scores / score_sum * routed_scaling_factor + ) + + # Sort by normalized scores + sorted_vals, sorted_idx = torch.sort(normalized_scores, descending=True) + self.ref_expert_values[token_idx] = sorted_vals + self.ref_expert_indices[token_idx] = topk_idx[sorted_idx] + + def get_expert_group(self, expert_id): + """Return which group an expert belongs to.""" + return expert_id // self.experts_per_group + + def is_valid_group_selection(self, token_idx, selected_groups): + """Check if a set of selected groups is valid (exact match or tied).""" + ref_groups = set(self.ref_group_indices[token_idx].tolist()) + selected_groups_set = 
set(selected_groups) + + if selected_groups_set == ref_groups: + return True, "exact" + + if self.n_group > 1: + diff_groups = selected_groups_set.symmetric_difference(ref_groups) + tied_groups = self.tied_group_sets[token_idx] + + if diff_groups and diff_groups.issubset(tied_groups): + return True, "tied_groups" + + return False, "different_groups" + + def is_valid_expert_selection(self, token_idx, selected_experts): + """Check if a set of selected experts is valid (exact match or tied).""" + ref_experts = set(self.ref_expert_indices[token_idx].tolist()) + selected_experts_set = set(selected_experts) + + if selected_experts_set == ref_experts: + return True, "exact" + + # Check group-level validity + selected_groups = set(self.get_expert_group(e) for e in selected_experts) + ref_groups = set(self.ref_group_indices[token_idx].tolist()) + + # If different groups selected + if selected_groups != ref_groups: + is_valid_groups, group_reason = self.is_valid_group_selection( + token_idx, list(selected_groups) + ) + if not is_valid_groups: + # Groups are different and not tied - invalid + return False, group_reason + + # Groups are tied - now check if kernel selected correct top-k within its groups + expected_experts_in_kernel_groups = self._get_topk_experts_from_groups( + token_idx, list(selected_groups) + ) + + # Check if kernel's selection matches expected experts (exact or tied) + if selected_experts_set != expected_experts_in_kernel_groups: + # Different experts - check if they have tied scores + diff_experts = selected_experts_set.symmetric_difference( + expected_experts_in_kernel_groups + ) + biased_scores_token = self.biased_scores[token_idx] + diff_expert_scores = torch.tensor( + [biased_scores_token[e].item() for e in diff_experts] + ) + score_range = diff_expert_scores.max() - diff_expert_scores.min() + + if score_range >= self.expert_tie_threshold: + # Experts are wrong (not tied) - invalid even though groups are tied + return ( + False, + f"tied_groups_but_wrong_experts_score_diff={score_range:.6f}", + ) + + # Groups are tied and experts are correct (or acceptably tied) + return True, "tied_groups" + + # Same groups but different experts - check expert-level ties + diff_experts = selected_experts_set.symmetric_difference(ref_experts) + if diff_experts: + biased_scores_token = self.biased_scores[token_idx] + diff_expert_scores = torch.tensor( + [biased_scores_token[e].item() for e in diff_experts] + ) + score_range = diff_expert_scores.max() - diff_expert_scores.min() + + if score_range < self.expert_tie_threshold: + return True, "tied_experts" + else: + return ( + False, + f"score_diff={score_range:.6f}_threshold={self.expert_tie_threshold:.6f}", + ) + + return True, "exact" + + def _get_topk_experts_from_groups(self, token_idx, groups): + """ + Get the expected top-k experts from specified groups. + This computes what experts SHOULD be selected if these groups were chosen. 
+ """ + # Create mask for specified groups + group_mask = torch.zeros(self.n_group, dtype=torch.float32, device=self.device) + for g in groups: + group_mask[g] = 1.0 + expert_mask = group_mask.repeat_interleave(self.experts_per_group) + + # Mask and select top-k experts + masked_biased_scores = self.biased_scores[token_idx] * expert_mask + _, topk_idx = torch.topk( + masked_biased_scores, k=self.topk, dim=-1, largest=True, sorted=True + ) + + return set(topk_idx.tolist()) + + +def validate_expert_selection(ground_truth, topk_indices_kernel, topk_values_kernel): + """Validate kernel outputs and provide detailed debug info for failures.""" + num_tokens = topk_indices_kernel.shape[0] + tokens_with_different_experts = set() + + for token_idx in range(num_tokens): + kernel_experts = topk_indices_kernel[token_idx].tolist() + ref_experts = ground_truth.ref_expert_indices[token_idx].tolist() + + # Same experts - valid + if set(kernel_experts) == set(ref_experts): + continue + + # Different experts - mark for value comparison skip + tokens_with_different_experts.add(token_idx) + + # Validate the selection + is_valid, reason = ground_truth.is_valid_expert_selection( + token_idx, kernel_experts + ) + + if not is_valid: + return False, tokens_with_different_experts + + return True, tokens_with_different_experts + + +def validate_values(ground_truth, topk_values_kernel, tokens_to_skip, data_type): + """Validate that output values match reference within tolerance.""" + # Set tolerance based on data type + if data_type == torch.bfloat16: + rtol, atol = 0.1, 0.1 + elif data_type == torch.float16: + rtol, atol = 0.05, 0.05 + else: # float32 + rtol, atol = 0.01, 0.01 + + num_tokens = topk_values_kernel.shape[0] + + # Create mask for tokens to check + tokens_to_check = torch.ones(num_tokens, dtype=torch.bool) + for token_idx in tokens_to_skip: + tokens_to_check[token_idx] = False + + if not tokens_to_check.any(): + return + + # Compare values + ref_values = ground_truth.ref_expert_values[tokens_to_check].float() + kernel_values = topk_values_kernel[tokens_to_check].float() + + try: + torch.testing.assert_close( + ref_values, + kernel_values, + rtol=rtol, + atol=atol, + ) + except AssertionError: + # Find and report first mismatch + for token_idx in range(num_tokens): + if not tokens_to_check[token_idx]: + continue + + ref_vals = ground_truth.ref_expert_values[token_idx].float() + kernel_vals = topk_values_kernel[token_idx].float() + + if not torch.allclose(ref_vals, kernel_vals, rtol=rtol, atol=atol): + diff = (kernel_vals - ref_vals).abs() + max_diff = diff.max().item() + max_diff_idx = diff.argmax().item() + + print(f"\n{'=' * 80}") + print(f"VALUE MISMATCH - Token {token_idx}") + print(f"{'=' * 80}") + print(f"Tolerance: rtol={rtol}, atol={atol}") + print(f"Max difference: {max_diff:.6f} at position {max_diff_idx}") + print(f"\nReference values: {ref_vals.tolist()}") + print(f"Kernel values: {kernel_vals.tolist()}") + print(f"Absolute diff: {diff.tolist()}") + print( + f"Expert indices: {ground_truth.ref_expert_indices[token_idx].tolist()}" + ) + break + + raise + + +@pytest.mark.parametrize("num_tokens", [1, 8, 16, 64]) +@pytest.mark.parametrize("num_experts", [256, 384]) +@pytest.mark.parametrize("topk", [1, 2, 4, 8]) +@pytest.mark.parametrize("n_group", [1, 2, 4, 8]) +@pytest.mark.parametrize("topk_group", [1, 2, 4, 8]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("bias_type", [torch.float32, torch.float16, torch.bfloat16]) +def 
test_dsv3_fused_routing_op( + num_tokens, num_experts, topk, n_group, topk_group, data_type, bias_type +): + """ + Test NoAuxTc kernel against reference implementation. + + Validates: + 1. Expert selection equivalence (allowing for ties) + 2. Value correctness within numerical precision tolerance + """ + + # Skip invalid configurations + if topk_group * n_group < topk or topk_group > n_group: + pytest.skip( + "Invalid configuration: topk_group * n_group < topk or topk_group > n_group" + ) + if n_group > 1: + if ( + topk > 8 + or num_experts / n_group > 32 + or num_experts / n_group * topk_group > 128 + ): + pytest.skip("Invalid configuration: exceeds kernel limits for n_group > 1") + else: + if num_experts > 384 or topk > 8: + pytest.skip("Invalid configuration: exceeds kernel limits for n_group = 1") + + # Generate random inputs + torch.manual_seed(42) + scores = torch.randn(num_tokens, num_experts, device="cuda", dtype=data_type) + bias = torch.randn(num_experts, device="cuda", dtype=bias_type) + routed_scaling_factor = 1.0 + + # Compute ground truth + ground_truth = DSv3RoutingGroundTruth( + scores.clone(), + bias.clone(), + n_group, + topk_group, + topk, + routed_scaling_factor, + data_type, + ) + + # Run kernel + topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=data_type) + topk_indices = torch.zeros(num_tokens, topk, device="cuda", dtype=torch.int32) + + NoAuxTc( + scores, + bias, + n_group, + topk_group, + topk, + routed_scaling_factor, + topk_values, + topk_indices, + launch_with_pdl=True, + ) + + # Sort kernel outputs for stable comparison + sorted_vals, sorted_idx = torch.sort(topk_values, dim=-1, descending=True) + topk_indices = topk_indices.gather(1, sorted_idx) + + # Validate expert selection + all_valid, tokens_with_different_experts = validate_expert_selection( + ground_truth, topk_indices, sorted_vals + ) + + if not all_valid: + pytest.fail("Expert selection mismatch not due to acceptable ties") + + # Validate values + validate_values(ground_truth, sorted_vals, tokens_with_different_experts, data_type) From 0753095742679c0f4ec9649b0dd3a611b4ee22d2 Mon Sep 17 00:00:00 2001 From: Augusto Yao Date: Thu, 20 Nov 2025 10:47:57 +0800 Subject: [PATCH 072/130] feature: make the LSE returned by MLA support base 2 or e #2113 (#2114) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This pr adds a parameter `return_lse_base_on_e` to control the base of LSE returned by MLA. Default to `False`, which keeps the same with current implementation. If `return_lse_base_on_e` is `True`, multiply the final LSE by `loge2` to maintain consistency with the standard softmax and FA3. ## ๐Ÿ” Related Issues #2113 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). 
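
As a quick reference for the semantics of the new flag: the rebasing it performs is a
single multiplication by ln(2). The sketch below is illustrative only (names such as
`lse_base2` and `rebase_lse` are placeholders, not identifiers from this PR):

```python
import math

def rebase_lse(lse_base2: float, return_lse_base_on_e: bool) -> float:
    """Convert a base-2 log-sum-exp to base e when requested."""
    # ln(x) = log2(x) * ln(2), so scaling the base-2 LSE by ln(2) (~0.6931)
    # matches the standard softmax / FA3 convention when the flag is set.
    return lse_base2 * math.log(2.0) if return_lse_base_on_e else lse_base2
```
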
## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Added a run-time option to control whether returned logโ€‘sumโ€‘exp (LSE) baselines are scaled by ln(2) (default: disabled). * **Bug Fixes** * Conditional scaling ensures returned LSE values are consistent when the option is enabled, improving numerical consistency. * **Chores** * The new option is exposed in public APIs and bindings and is propagated through the execution path. --------- Signed-off-by: augusto.yjh --- csrc/batch_mla_binding.cu | 3 ++- csrc/batch_mla_run.cu | 4 +++- csrc/batch_mla_sm90_binding.cu | 4 ++-- csrc/batch_mla_sm90_run.cu | 5 +++-- flashinfer/mla.py | 4 ++++ include/flashinfer/attention/mla.cuh | 17 ++++++++++++----- include/flashinfer/attention/mla_hopper.cuh | 12 ++++++++---- include/flashinfer/attention/mla_params.cuh | 1 + 8 files changed, 35 insertions(+), 15 deletions(-) diff --git a/csrc/batch_mla_binding.cu b/csrc/batch_mla_binding.cu index 6822e28b93..b39192de6a 100644 --- a/csrc/batch_mla_binding.cu +++ b/csrc/batch_mla_binding.cu @@ -31,7 +31,8 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int Array plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, - int64_t num_heads, int64_t page_size, double sm_scale); + int64_t num_heads, int64_t page_size, double sm_scale, + bool return_lse_base_on_e); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchMLAPagedAttentionPlan); TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, BatchMLAPagedAttentionRun); diff --git a/csrc/batch_mla_run.cu b/csrc/batch_mla_run.cu index dfa2442f1b..9d950787ad 100644 --- a/csrc/batch_mla_run.cu +++ b/csrc/batch_mla_run.cu @@ -31,7 +31,8 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int Array plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, - int64_t num_heads, int64_t page_size, double sm_scale) { + int64_t num_heads, int64_t page_size, double sm_scale, + bool return_lse_base_on_e) { // q_nope: [n, num_heads, head_dim_ckv] // q_pe: [n, num_heads, head_dim_kpe] // ckv_cache: [num_pages, page_size, head_dim_ckv] @@ -112,6 +113,7 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int params.o_stride_h = o_stride_h; params.sm_scale = sm_scale; + params.return_lse_base_on_e = return_lse_base_on_e; cudaError_t status = mla::BatchMLAPagedAttention( params, plan_info.num_blks_x, plan_info.num_blks_y, stream); diff --git a/csrc/batch_mla_sm90_binding.cu b/csrc/batch_mla_sm90_binding.cu index 2e6cd1aa7d..f2af49766a 100644 --- a/csrc/batch_mla_sm90_binding.cu +++ b/csrc/batch_mla_sm90_binding.cu @@ -32,8 +32,8 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, TensorView q_nope, TensorView q_pe, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, - int64_t num_heads, int64_t page_size, - double sm_scale ADDITIONAL_FUNC_PARAMS); + int64_t num_heads, int64_t page_size, double sm_scale, + bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchMLAPagedAttentionSM90Plan); TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, BatchMLAPagedAttentionSM90Run); diff --git a/csrc/batch_mla_sm90_run.cu b/csrc/batch_mla_sm90_run.cu index 8d6d80c223..c6670ca342 100644 --- a/csrc/batch_mla_sm90_run.cu +++ 
b/csrc/batch_mla_sm90_run.cu @@ -31,8 +31,8 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, TensorView q_nope, TensorView q_pe, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, - int64_t num_heads, int64_t page_size, - double sm_scale ADDITIONAL_FUNC_PARAMS) { + int64_t num_heads, int64_t page_size, double sm_scale, + bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS) { // q_nope: [n, num_heads, head_dim_ckv] // q_pe: [n, num_heads, head_dim_kpe] // ckv_cache: [num_pages, page_size, head_dim_ckv] @@ -111,6 +111,7 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, params.kpe_stride_n = kpe_stride_n; params.o_stride_n = o_stride_n; params.o_stride_h = o_stride_h; + params.return_lse_base_on_e = return_lse_base_on_e; ADDITIONAL_PARAMS_SETTER diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 490ae7edf0..da57d94e6b 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -314,6 +314,7 @@ def run( profiler_buffer: Optional[torch.Tensor] = None, kv_len: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, + return_lse_base_on_e: bool = False, ) -> torch.Tensor: ... @overload @@ -329,6 +330,7 @@ def run( profiler_buffer: Optional[torch.Tensor] = None, kv_len: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, + return_lse_base_on_e: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ... def run( @@ -343,6 +345,7 @@ def run( profiler_buffer: Optional[torch.Tensor] = None, kv_len: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, + return_lse_base_on_e: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Run the MLA attention computation. 
@@ -441,6 +444,7 @@ def run( num_heads, page_size, sm_scale, + return_lse_base_on_e, *profiler_args, ) diff --git a/include/flashinfer/attention/mla.cuh b/include/flashinfer/attention/mla.cuh index 31401ff1c5..9882cd027a 100644 --- a/include/flashinfer/attention/mla.cuh +++ b/include/flashinfer/attention/mla.cuh @@ -628,7 +628,8 @@ __device__ void DevicePersistentMergeStates( typename KTraits::IdType* merge_partial_packed_offset_end, typename KTraits::IdType* merge_partial_stride, typename KTraits::DTypeO* partial_o, float* partial_lse, typename KTraits::DTypeO* final_o, float* final_lse, - const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv& num_heads) { + const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv& num_heads, + const bool& return_lse_base_on_e) { constexpr uint32_t VEC_SIZE = 8; // partial o has data type float constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE; constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW; @@ -661,6 +662,9 @@ __device__ void DevicePersistentMergeStates( (q * o_stride_n + r * o_stride_h + (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE)); if (final_lse) { final_lse[q * num_heads + r] = st.get_lse(); + if (return_lse_base_on_e) { + final_lse[q * num_heads + r] *= math::loge2; + } } } } @@ -672,8 +676,8 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st float (*o_frag)[8], typename KTraits::DTypeQKAccum* m, float* d, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint32_t q_len, - const uint32_t packed_offset, - const uint_fastdiv& num_heads) { + const uint32_t packed_offset, const uint_fastdiv& num_heads, + const bool& return_lse_base_on_e) { using DTypeO = typename KTraits::DTypeO; constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV; constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV; @@ -744,6 +748,9 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r); if (lane_idx % 4 == 0 && q < q_len) { final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]); + if (return_lse_base_on_e) { + final_lse[q * num_heads + r] *= math::loge2; + } } } } @@ -967,7 +974,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe final_lse ? final_lse + q_indptr * num_heads : nullptr, (partial_indptr == -1) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV, (partial_indptr == -1) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n, - o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads); + o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads, params.return_lse_base_on_e); } auto grid = cg::this_grid(); @@ -978,7 +985,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe params.merge_packed_offset_start, params.merge_packed_offset_end, params.merge_partial_packed_offset_start, params.merge_partial_packed_offset_end, params.merge_partial_stride, partial_o, partial_lse, final_o, final_lse, o_stride_n, - o_stride_h, num_heads); + o_stride_h, num_heads, params.return_lse_base_on_e); } #define DISPATCH_SMEM_CONFIG(smem_limit_per_sm, NUM_STAGES, CTA_TILE_KV, QK_SHARD, ...) 
\ diff --git a/include/flashinfer/attention/mla_hopper.cuh b/include/flashinfer/attention/mla_hopper.cuh index efcf660fbc..65b1d60573 100644 --- a/include/flashinfer/attention/mla_hopper.cuh +++ b/include/flashinfer/attention/mla_hopper.cuh @@ -455,7 +455,8 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st float* partial_lse, float(*o_frag), float* m, float* d, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint32_t q_len, const uint32_t packed_offset, - const uint_fastdiv& num_heads) { + const uint_fastdiv& num_heads, + const bool& return_lse_base_on_e) { using DTypeO = typename KTraits::DTypeO; constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV; constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV; @@ -543,6 +544,9 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r); if (lane_idx % 4 == 0 && q < q_len) { final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]); + if (return_lse_base_on_e) { + final_lse[q * num_heads + r] *= math::loge2; + } } } } @@ -796,7 +800,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop final_lse ? final_lse + q_indptr * num_heads : nullptr, (partial_indptr == -1) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV, (partial_indptr == -1) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n, - o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads); + o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads, params.return_lse_base_on_e); PROFILER_EVENT_END(variant, ProfileEventType::kWriteO); __syncthreads(); } @@ -936,7 +940,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop final_lse ? final_lse + q_indptr * num_heads : nullptr, (partial_indptr == -1) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV, (partial_indptr == -1) ? 
nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n, - o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads); + o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads, params.return_lse_base_on_e); PROFILER_EVENT_END(variant, ProfileEventType::kWriteO); __syncthreads(); } @@ -953,7 +957,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop params.merge_packed_offset_start, params.merge_packed_offset_end, params.merge_partial_packed_offset_start, params.merge_partial_packed_offset_end, params.merge_partial_stride, partial_o, partial_lse, final_o, final_lse, o_stride_n, - o_stride_h, num_heads); + o_stride_h, num_heads, params.return_lse_base_on_e); PROFILER_EVENT_END(variant, ProfileEventType::kSplitK); } diff --git a/include/flashinfer/attention/mla_params.cuh b/include/flashinfer/attention/mla_params.cuh index ff5d168ba2..6da1ed7f53 100644 --- a/include/flashinfer/attention/mla_params.cuh +++ b/include/flashinfer/attention/mla_params.cuh @@ -71,6 +71,7 @@ struct MLAParams { uint32_t o_stride_h; float sm_scale; + bool return_lse_base_on_e; }; }; // namespace flashinfer From 76eea79b817ffbb154dbb730023597261bd11894 Mon Sep 17 00:00:00 2001 From: qsang-nv <200703406+qsang-nv@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:34:29 +0800 Subject: [PATCH 073/130] update xqa license (#2117) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Update xqa license based on https://github.com/NVIDIA/TensorRT-LLM/pull/8807 ## ๐Ÿ” Related Issues https://github.com/flashinfer-ai/flashinfer/issues/1977 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Updated project licensing to Apache License 2.0 with extended copyright years through 2025. โœ๏ธ Tip: You can customize this high-level summary in your review settings. 
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> --- csrc/xqa/barriers.cuh | 21 +++++++++++++-------- csrc/xqa/cuda_hint.cuh | 21 +++++++++++++-------- csrc/xqa/defines.h | 21 +++++++++++++-------- csrc/xqa/gmma.cuh | 21 +++++++++++++-------- csrc/xqa/gmma_impl.cuh | 21 +++++++++++++-------- csrc/xqa/ldgsts.cuh | 21 +++++++++++++-------- csrc/xqa/mha.cu | 21 +++++++++++++-------- csrc/xqa/mha.h | 21 +++++++++++++-------- csrc/xqa/mhaUtils.cuh | 21 +++++++++++++-------- csrc/xqa/mha_sm90.cu | 21 +++++++++++++-------- csrc/xqa/mha_stdheaders.cuh | 21 +++++++++++++-------- csrc/xqa/mla_sm120.cu | 21 +++++++++++++-------- csrc/xqa/mma.cuh | 21 +++++++++++++-------- csrc/xqa/platform.h | 21 +++++++++++++-------- csrc/xqa/specDec.h | 21 +++++++++++++-------- csrc/xqa/tma.h | 21 +++++++++++++-------- csrc/xqa/utils.cuh | 21 +++++++++++++-------- csrc/xqa/utils.h | 21 +++++++++++++-------- 18 files changed, 234 insertions(+), 144 deletions(-) diff --git a/csrc/xqa/barriers.cuh b/csrc/xqa/barriers.cuh index c65b755294..ad5a77a72e 100644 --- a/csrc/xqa/barriers.cuh +++ b/csrc/xqa/barriers.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/cuda_hint.cuh b/csrc/xqa/cuda_hint.cuh index d6e2af86eb..8007e4c3d4 100644 --- a/csrc/xqa/cuda_hint.cuh +++ b/csrc/xqa/cuda_hint.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/defines.h b/csrc/xqa/defines.h index ca8589d808..3794708a3b 100644 --- a/csrc/xqa/defines.h +++ b/csrc/xqa/defines.h @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/gmma.cuh b/csrc/xqa/gmma.cuh index d1b2547fcd..a62f34a434 100644 --- a/csrc/xqa/gmma.cuh +++ b/csrc/xqa/gmma.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ #pragma once diff --git a/csrc/xqa/gmma_impl.cuh b/csrc/xqa/gmma_impl.cuh index b9515ddea9..6c47fa0bf9 100644 --- a/csrc/xqa/gmma_impl.cuh +++ b/csrc/xqa/gmma_impl.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/ldgsts.cuh b/csrc/xqa/ldgsts.cuh index 779a13429c..86ac6b7f7b 100644 --- a/csrc/xqa/ldgsts.cuh +++ b/csrc/xqa/ldgsts.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index 016a4f982a..8699510dcc 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "cuda_hint.cuh" diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h index 872cd45059..9681e5f9e4 100644 --- a/csrc/xqa/mha.h +++ b/csrc/xqa/mha.h @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/mhaUtils.cuh b/csrc/xqa/mhaUtils.cuh index 869862f204..7fd53f9344 100644 --- a/csrc/xqa/mhaUtils.cuh +++ b/csrc/xqa/mhaUtils.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu index 9b751817c5..3d77535909 100644 --- a/csrc/xqa/mha_sm90.cu +++ b/csrc/xqa/mha_sm90.cu @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "cuda_hint.cuh" diff --git a/csrc/xqa/mha_stdheaders.cuh b/csrc/xqa/mha_stdheaders.cuh index c76e759c32..353a8bbd9c 100644 --- a/csrc/xqa/mha_stdheaders.cuh +++ b/csrc/xqa/mha_stdheaders.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ #pragma once diff --git a/csrc/xqa/mla_sm120.cu b/csrc/xqa/mla_sm120.cu index 30863edced..42a37a9b9f 100644 --- a/csrc/xqa/mla_sm120.cu +++ b/csrc/xqa/mla_sm120.cu @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "defines.h" diff --git a/csrc/xqa/mma.cuh b/csrc/xqa/mma.cuh index fac96843aa..c8e425f213 100644 --- a/csrc/xqa/mma.cuh +++ b/csrc/xqa/mma.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/platform.h b/csrc/xqa/platform.h index cb1a9e7c58..797d4234ab 100644 --- a/csrc/xqa/platform.h +++ b/csrc/xqa/platform.h @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/specDec.h b/csrc/xqa/specDec.h index 7a4131a59c..22e1e9c566 100644 --- a/csrc/xqa/specDec.h +++ b/csrc/xqa/specDec.h @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/tma.h b/csrc/xqa/tma.h index 5cf67238a2..d0137ffefc 100644 --- a/csrc/xqa/tma.h +++ b/csrc/xqa/tma.h @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/utils.cuh b/csrc/xqa/utils.cuh index 6302d4e20b..a9ac1805b9 100644 --- a/csrc/xqa/utils.cuh +++ b/csrc/xqa/utils.cuh @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #pragma once diff --git a/csrc/xqa/utils.h b/csrc/xqa/utils.h index d685b72d0d..45e18d3a2b 100644 --- a/csrc/xqa/utils.h +++ b/csrc/xqa/utils.h @@ -1,13 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights - * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ #pragma once From af25b45c0579ba73f1d45da3be8a9f440469f5ec Mon Sep 17 00:00:00 2001 From: qsang-nv <200703406+qsang-nv@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:36:45 +0800 Subject: [PATCH 074/130] add tensor scale input for xqa (#2110) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Attention ops now accept tensor-based per-head scaling (q/kv) in C++ and Python paths, enabling dynamic or per-tensor quantization scales. * Python APIs and docs updated to accept float or tensor scales. * **Tests** * Batch-decode tests adjusted to use per-sequence cache/block sizing for more accurate memory dimensioning. * **Documentation** * Docstrings updated to describe tensor-or-scalar scale inputs. --------- Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> --- csrc/flashinfer_xqa_binding.cu | 12 +++-- csrc/xqa/mha.cu | 48 ++++++++++-------- csrc/xqa/mha.h | 53 ++++++++++---------- csrc/xqa/mha_sm90.cu | 51 ++++++++++--------- csrc/xqa/mla_sm120.cu | 37 +++++++------- csrc/xqa/xqa_wrapper.cu | 36 +++++++++---- flashinfer/decode.py | 45 +++++------------ flashinfer/xqa.py | 47 +++++++++-------- tests/attention/test_xqa_batch_decode.py | 4 +- tests/attention/test_xqa_mla_batch_decode.py | 4 +- 10 files changed, 174 insertions(+), 163 deletions(-) diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu index e21eb3a73d..dc06614763 100644 --- a/csrc/flashinfer_xqa_binding.cu +++ b/csrc/flashinfer_xqa_binding.cu @@ -17,20 +17,24 @@ #include "tvm_ffi_utils.h" #if MLA_WRAPPER -void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q, +void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, + tvm::ffi::Optional qScaleTensor, TensorView output, TensorView q, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, - TensorView semaphores, TensorView scratch, bool enable_pdl); + tvm::ffi::Optional kvScaleTensor, TensorView semaphores, + TensorView scratch, bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla); #else void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, - int64_t slidingWinSize, double qScale, TensorView output, double rcpOutScale, - TensorView q, tvm::ffi::Optional attentionSinks, TensorView kCacheVLLM, + int64_t slidingWinSize, double qScale, tvm::ffi::Optional qScaleTensor, + TensorView output, double rcpOutScale, TensorView q, + tvm::ffi::Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, + 
tvm::ffi::Optional kvScaleTensor, #if SPEC_DEC int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index 8699510dcc..af61dc0034 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -1283,7 +1283,7 @@ CUBIN_EXPORT __global__ #if SLIDING_WINDOW uint32_t slidingWinSize, #endif - float qScale, + float qScale, float const* qScalePtr, OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] #if LOW_PREC_OUTPUT float rcpOutScale, @@ -1305,10 +1305,13 @@ CUBIN_EXPORT __global__ BeamSearchParams const beamSearchParams, #endif #endif - uint32_t const batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t const batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head, uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { + + float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale; + float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale; assert(allowMultiBlockMode || gridDim.x == 1); bool const isMultiBlock = allowMultiBlockMode && (gridDim.x != 1); uint32_t const nbSubSeqPerSeq = allowMultiBlockMode ? gridDim.x : 1; @@ -1507,7 +1510,7 @@ CUBIN_EXPORT __global__ }; if (warpIdx.z == 0) { float const qkScale = - qScale * (isKVCacheQuantized ? kvCacheScale : 1.f) * + qScaleValue * (isKVCacheQuantized ? kvCacheScaleValue : 1.f) * rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. CircIdx idxCurrSMemKBuf{nbKBuffers - 1}; auto const getSMemKTile = [&](uint32_t idx) -> SharedMem::KSmemBuffer& { @@ -2160,7 +2163,7 @@ CUBIN_EXPORT __global__ } } - float voScale = (isKVCacheQuantized ? kvCacheScale : 1.F); + float voScale = (isKVCacheQuantized ? kvCacheScaleValue : 1.F); if (seqIterInit < nbSeqIters) { // otherwise rcpRowSum will be NAN. // The attention sinks are moved to the multi-block reduction part if the multi-block is // enabled. @@ -2398,7 +2401,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( #if SLIDING_WINDOW uint32_t slidingWinSize, #endif - float qScale, + float qScale, float const* qScalePtr, OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] #if LOW_PREC_OUTPUT float rcpOutScale, @@ -2413,8 +2416,8 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( #if BEAM_WIDTH > 1 BeamSearchParams const beamSearchParams, #endif - uint32_t const batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t const batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. 
uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head, uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { #if SPEC_DEC @@ -2425,7 +2428,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( #if SLIDING_WINDOW slidingWinSize, #endif - qScale, output, + qScale, qScalePtr, output, #if LOW_PREC_OUTPUT rcpOutScale, #endif @@ -2437,8 +2440,8 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( #if BEAM_WIDTH > 1 beamSearchParams, #endif - batchSize, kvCacheScale, kv_stride_page, kv_stride_token, kv_stride_head, - semaphores, scratch); + batchSize, kvCacheScale, kvScalePtr, kv_stride_page, kv_stride_token, + kv_stride_head, semaphores, scratch); } #else static constexpr auto kernel_mha = kernel_mha_impl; @@ -2450,7 +2453,7 @@ void launchMHA( #if SLIDING_WINDOW uint32_t slidingWinSize, #endif - float qScale, OutputHead* output, + float qScale, float const* qScalePtr, OutputHead* output, #if LOW_PREC_OUTPUT float rcpOutScale, #endif @@ -2471,8 +2474,8 @@ void launchMHA( #if BEAM_WIDTH > 1 BeamSearchParams const& beamSearchParams, #endif - uint32_t batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, #endif @@ -2537,7 +2540,7 @@ void launchMHA( #if SLIDING_WINDOW slidingWinSize, #endif - qScale, output, + qScale, qScalePtr, output, #if LOW_PREC_OUTPUT rcpOutScale, #endif @@ -2549,8 +2552,8 @@ void launchMHA( #if BEAM_WIDTH > 1 beamSearchParams, #endif - batchSize, kvCacheScale, stride_page_in_heads, stride_token_in_heads, - stride_head_in_heads, semaphores, scratch); + batchSize, kvCacheScale, kvScalePtr, stride_page_in_heads, + stride_token_in_heads, stride_head_in_heads, semaphores, scratch); checkCuda(cudaPeekAtLastError()); #endif // USE_INPUT_KV } @@ -2566,14 +2569,14 @@ static uint32_t configureKernel() { static uint32_t const hostSmemSize = configureKernel(); void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, - float qScale, OutputHead* output, + float qScale, float const* qScalePtr, OutputHead* output, #if LOW_PREC_OUTPUT float rcpOutScale, #endif InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float kvCacheScale, + float kvCacheScale, float const* kvScalePtr, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif @@ -2612,7 +2615,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 #if SLIDING_WINDOW slidingWinSize, #endif - qScale, output, + qScale, qScalePtr, output, #if LOW_PREC_OUTPUT rcpOutScale, #endif @@ -2620,8 +2623,9 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 #if SPEC_DEC mask, #endif - attentionSinks, cacheList, batchSize, kvCacheScale, stride_page_in_heads, - stride_token_in_heads, stride_head_in_heads, semaphores, scratch); + attentionSinks, cacheList, batchSize, kvCacheScale, kvScalePtr, + stride_page_in_heads, stride_token_in_heads, stride_head_in_heads, semaphores, + scratch); checkCuda(cudaPeekAtLastError()); } #endif diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h index 9681e5f9e4..d7ab1c452c 100644 --- a/csrc/xqa/mha.h +++ 
b/csrc/xqa/mha.h @@ -98,7 +98,7 @@ void launchMHA( #if SLIDING_WINDOW uint32_t slidingWinSize, #endif - float qScale, OutputHead* output, + float qScale, float const* qScalePtr, OutputHead* output, #if LOW_PREC_OUTPUT float rcpOutScale, #endif @@ -119,8 +119,8 @@ void launchMHA( #if BEAM_WIDTH > 1 BeamSearchParams const& beamSearchParams, #endif - uint32_t batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, #endif @@ -128,14 +128,14 @@ void launchMHA( uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream); void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, - float qScale, OutputHead* output, + float qScale, float const* qScalePtr, OutputHead* output, #if LOW_PREC_OUTPUT float rcpOutScale, #endif InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float kvCacheScale, + float kvCacheScale, float const* kvScalePtr, #if SPEC_DEC uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif @@ -148,7 +148,7 @@ void launchHopperF8MHA( #if SLIDING_WINDOW uint32_t slidingWinSize, #endif - float qScale, OutputHead* output, + float qScale, float const* qScalePtr, OutputHead* output, #if LOW_PREC_OUTPUT float rcpOutScale, #endif @@ -169,53 +169,52 @@ void launchHopperF8MHA( #if BEAM_WIDTH > 1 BeamSearchParams const& beamSearchParams, #endif - uint32_t batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. 
#if SPEC_DEC SpecDecParams const& specDecParams, #endif uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream); -void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, - uint32_t slidingWinSize, float qScale, OutputHead* output, +void launchHopperF8MHAFlashInfer( + uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, float qScale, + float const* qScalePtr, OutputHead* output, #if LOW_PREC_OUTPUT - float rcpOutScale, + float rcpOutScale, #endif - InputHead const* q, float const* attentionSinks, - GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, - KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, - uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, + InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, + GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, + uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, float const* kvScalePtr, #if SPEC_DEC - uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, + uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif - uint32_t* semaphores, void* scratch, bool enable_pdl, - uint64_t kv_stride_page, uint64_t kv_stride_token, - uint64_t kv_stride_head, cudaStream_t stream); + uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, + uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream); void launchMLA( cudaDeviceProp const& prop, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed - float qScale, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM, - GMemCacheHead* vCacheVLLM, + float qScale, float const* qScalePtr, OutputHead* output, InputHead const* q, + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] // (Layout 0) or [batchSize][maxNbPagesPerSeq] (Layout 1) - uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream); void launchMLAFlashInfer( uint32_t multiProcessorCount, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed - float qScale, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM, - GMemCacheHead* vCacheVLLM, + float qScale, float const* qScalePtr, OutputHead* output, InputHead const* q, + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, // device pointer. shape: // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or // [batchSize][maxNbPagesPerSeq] (Layout 1) - uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. 
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream); diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu index 3d77535909..06938edd91 100644 --- a/csrc/xqa/mha_sm90.cu +++ b/csrc/xqa/mha_sm90.cu @@ -612,7 +612,7 @@ __launch_bounds__(128 * 3) #if SLIDING_WINDOW uint32_t const slidingWinSize, #endif - float const qScale, + float const qScale, float const* qScalePtr, OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] #if LOW_PREC_OUTPUT float rcpOutScale, @@ -630,8 +630,8 @@ __launch_bounds__(128 * 3) #if USE_BEAM_SEARCH BeamSearchParams const beamSearchParams, #endif - uint32_t const batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t const batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. __grid_constant__ CUtensorMap const tensorMapVLLMK, __grid_constant__ CUtensorMap const tensorMapVLLMV, #if SPEC_DEC @@ -640,6 +640,8 @@ __launch_bounds__(128 * 3) uint32_t* __restrict__ const semaphores = nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)] void* __restrict__ const scratch = nullptr) { + float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale; + float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) && \ (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1 uint32_t const idxReq = blockIdx.z / nbKHeads; @@ -777,7 +779,7 @@ __launch_bounds__(128 * 3) } float const qkScale = - qScale * (isKVCacheQuantized ? kvCacheScale : 1.f) * + qScaleValue * (isKVCacheQuantized ? kvCacheScaleValue : 1.f) * rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. uint32_t const warpRank = warpIdx.x; @@ -966,7 +968,7 @@ __launch_bounds__(128 * 3) #else constexpr float oScale = 1.F; #endif - float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale : 1.f) * oScale; + float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScaleValue : 1.f) * oScale; Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction. gmma::fence(); @@ -1320,7 +1322,7 @@ __launch_bounds__(128 * 3) headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; IOHead const& inKHead = qkv[inputKHeadOffset]; uint32_t const lane = laneId(); - float const rcpKScale = 1.F / kvCacheScale; + float const rcpKScale = 1.F / kvCacheScaleValue; #if ROPE_STYLE == 0 constexpr bool isNeox = false; auto const pairs = @@ -1379,7 +1381,7 @@ __launch_bounds__(128 * 3) (headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; IOHead const& inVHead = qkv[inputVHeadOffset]; uint32_t const lane = laneId(); - float const rcpVScale = 1.F / kvCacheScale; + float const rcpVScale = 1.F / kvCacheScaleValue; constexpr bool isNeox = false; auto const pairs = loadHead(inVHead, lane) * rcpVScale; @@ -2913,7 +2915,7 @@ void launchHopperF8MHA( #if SLIDING_WINDOW uint32_t slidingWinSize, #endif - float qScale, OutputHead* output, + float qScale, float const* qScalePtr, OutputHead* output, #if LOW_PREC_OUTPUT float rcpOutScale, #endif @@ -2934,8 +2936,8 @@ void launchHopperF8MHA( #if USE_BEAM_SEARCH BeamSearchParams const& beamSearchParams, #endif - uint32_t batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. 
+ uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, #endif @@ -3005,7 +3007,7 @@ void launchHopperF8MHA( #if SLIDING_WINDOW slidingWinSize, #endif - qScale, output, + qScale, qScalePtr, output, #if LOW_PREC_OUTPUT rcpOutScale, #endif @@ -3021,7 +3023,7 @@ void launchHopperF8MHA( #if USE_BEAM_SEARCH beamSearchParams, #endif - batchSize, kvCacheScale, tensorMapVLLMK, tensorMapVLLMV, + batchSize, kvCacheScale, kvScalePtr, tensorMapVLLMK, tensorMapVLLMV, #if SPEC_DEC specDecParams, #endif @@ -3039,21 +3041,20 @@ static uint32_t configureKernel() { static uint32_t const hostSmemSize = configureKernel(); -void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, - uint32_t slidingWinSize, float qScale, OutputHead* output, +void launchHopperF8MHAFlashInfer( + uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, float qScale, + float const* qScalePtr, OutputHead* output, #if LOW_PREC_OUTPUT - float rcpOutScale, + float rcpOutScale, #endif - InputHead const* q, float const* attentionSinks, - GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, - KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, - uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, + InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM, + GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, + uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, float const* kvScalePtr, #if SPEC_DEC - uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, + uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, #endif - uint32_t* semaphores, void* scratch, bool enable_pdl, - uint64_t kv_stride_page, uint64_t kv_stride_token, - uint64_t kv_stride_head, cudaStream_t stream) { + uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, + uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { float const factor = 0.25f; return mha::min( @@ -3096,12 +3097,12 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads #if SLIDING_WINDOW slidingWinSize, #endif - qScale, output, + qScale, qScalePtr, output, #if LOW_PREC_OUTPUT rcpOutScale, #endif q, attentionSinks, cacheList, batchSize, kvCacheScale, - tensorMapVLLMK, tensorMapVLLMV, + kvScalePtr, tensorMapVLLMK, tensorMapVLLMV, #if SPEC_DEC specDecParams, #endif diff --git a/csrc/xqa/mla_sm120.cu b/csrc/xqa/mla_sm120.cu index 42a37a9b9f..495d9e94d0 100644 --- a/csrc/xqa/mla_sm120.cu +++ b/csrc/xqa/mla_sm120.cu @@ -1554,16 +1554,18 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha __grid_constant__ CUtensorMap const tensorMapQ, // MhaIOHead[nbQHeads * totalNbInputTokens], __grid_constant__ CUtensorMap const tensorMapK, // with box=64 for the least significant dim __grid_constant__ CUtensorMap const tensorMapV, // with box=128 for the least significant dim - float const qScale, + float const qScale, float const* qScalePtr, OutputHead* __restrict__ const output, // [totalNbIntputTokens][nbQHeads] - KVCacheList const cacheList, uint32_t const batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + KVCacheList const cacheList, uint32_t const batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. 
Used only for int8/fp8 KV cache. Vec* __restrict__ const cgaXBuf, // [totalNbInputTokens][maxNbSubSeq] uint32_t* __restrict__ const semaphores = nullptr, // [totalNbInputTokens] PartialResult* __restrict__ const partialResults = nullptr) // [totalNbInputTokens][maxNbSubSeq] { + float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale; + float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale; assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1); extern __shared__ char smemBuf[]; uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size); @@ -1594,8 +1596,9 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha uint32_t const ctaRank = clusterCtaRank(); bool const isProducer = (ctaRank < nbProducerCtasPerCga); - KernelArgs const args{tensorMapQ, tensorMapK, tensorMapV, qScale, output, cacheList, - batchSize, kvCacheScale, cgaXBuf, semaphores, partialResults}; + KernelArgs const args{tensorMapQ, tensorMapK, tensorMapV, qScaleValue, + output, cacheList, batchSize, kvCacheScaleValue, + cgaXBuf, semaphores, partialResults}; if (isProducer) { Producer{args, @@ -1654,13 +1657,13 @@ CUtensorMap makeTensorMapForQ(void const* addr, CUtensorMapDataType_enum dataTyp void launchMLA( cudaDeviceProp const& prop, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed - float qScale, OutputHead* output, InputHead const* q, + float qScale, float const* qScalePtr, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout KVCachePageIndex const* kvCachePageList, // device pointer. shape: // [batchSize][maxNbPagesPerSeq] (Layout 1) - uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { #if IS_MLA @@ -1727,9 +1730,9 @@ void launchMLA( uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z; auto const cgaXBuf = static_cast*>(scratch); auto const partialResults = reinterpret_cast(cgaXBuf + nbCgas); - cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, tensorMapQ, tensorMapK, - tensorMapV, qScale, output, cacheList, batchSize, - kvCacheScale, cgaXBuf, semaphores, partialResults); + cudaError_t const err = cudaLaunchKernelEx( + &launchCfg, &kernel_mha, tensorMapQ, tensorMapK, tensorMapV, qScale, qScalePtr, output, + cacheList, batchSize, kvCacheScale, kvScalePtr, cgaXBuf, semaphores, partialResults); #else KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; static_assert(!usePagedKVCache); @@ -1775,13 +1778,13 @@ static uint32_t const hostSmemSize = configureKernel(); void launchMLAFlashInfer( uint32_t multiProcessorCount, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed - float qScale, OutputHead* output, InputHead const* q, + float qScale, float const* qScalePtr, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout KVCachePageIndex const* kvCachePageList, // device pointer. 
shape: // [batchSize][maxNbPagesPerSeq] (Layout 1) - uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, - float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache. + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale, + float const* kvScalePtr, // Same scale for K and V cache. Used only for int8/fp8 KV cache. uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page, uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) { #if IS_MLA @@ -1834,9 +1837,9 @@ void launchMLAFlashInfer( uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z; auto const cgaXBuf = static_cast*>(scratch); auto const partialResults = reinterpret_cast(cgaXBuf + nbCgas); - cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, tensorMapQ, tensorMapK, - tensorMapV, qScale, output, cacheList, batchSize, - kvCacheScale, cgaXBuf, semaphores, partialResults); + cudaError_t const err = cudaLaunchKernelEx( + &launchCfg, &kernel_mha, tensorMapQ, tensorMapK, tensorMapV, qScale, qScalePtr, output, + cacheList, batchSize, kvCacheScale, kvScalePtr, cgaXBuf, semaphores, partialResults); checkCuda(err); #endif } diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 796a4b33ef..3f9d637b42 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -20,35 +20,42 @@ using tvm::ffi::Optional; #if MLA_WRAPPER -void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q, - TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, - int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, +void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, Optional qScaleTensor, + TensorView output, TensorView q, TensorView kCacheVLLM, TensorView vCacheVLLM, + TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, + int64_t batchSize, double kvCacheScale, Optional kvScaleTensor, TensorView semaphores, TensorView scratch, bool enable_pdl) { auto stream = get_stream(output.device()); - + float const* qScalePtr = qScaleTensor.has_value() + ? reinterpret_cast(qScaleTensor.value().data_ptr()) + : nullptr; + float const* kvScalePtr = kvScaleTensor.has_value() + ? 
reinterpret_cast(kvScaleTensor.value().data_ptr()) + : nullptr; // Extract strides from TensorView (in elements, not bytes) uint64_t kv_stride_page = kCacheVLLM.stride(0); uint64_t kv_stride_token = kCacheVLLM.stride(-2); uint64_t kv_stride_head = kCacheVLLM.stride(-3); - launchMLAFlashInfer(multiProcessorCount, 1, qScale, + launchMLAFlashInfer(multiProcessorCount, 1, qScale, qScalePtr, reinterpret_cast(output.data_ptr()), reinterpret_cast(q.data_ptr()), reinterpret_cast(kCacheVLLM.data_ptr()), reinterpret_cast(vCacheVLLM.data_ptr()), reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, - kvCacheScale, reinterpret_cast(semaphores.data_ptr()), + kvCacheScale, kvScalePtr, reinterpret_cast(semaphores.data_ptr()), reinterpret_cast(scratch.data_ptr()), enable_pdl, kv_stride_page, kv_stride_token, kv_stride_head, stream); } #else void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, - int64_t slidingWinSize, double qScale, TensorView output, double rcpOutScale, - TensorView q, Optional attentionSinks, TensorView kCacheVLLM, - TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, - TensorView seqLen, int64_t batchSize, double kvCacheScale, + int64_t slidingWinSize, double qScale, Optional qScaleTensor, + TensorView output, double rcpOutScale, TensorView q, + Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, + TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, + int64_t batchSize, double kvCacheScale, Optional kvScaleTensor, #if SPEC_DEC int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif @@ -57,6 +64,12 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK float const* attentionSinksPtr = attentionSinks.has_value() ? reinterpret_cast(attentionSinks.value().data_ptr()) : nullptr; + float const* qScalePtr = qScaleTensor.has_value() + ? reinterpret_cast(qScaleTensor.value().data_ptr()) + : nullptr; + float const* kvScalePtr = kvScaleTensor.has_value() + ? reinterpret_cast(kvScaleTensor.value().data_ptr()) + : nullptr; auto const mha_func = run_sm90_fp8_mha ? 
&launchHopperF8MHAFlashInfer : &launchMHAFlashInfer; // Extract strides from TensorView (in elements, not bytes) @@ -64,7 +77,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK uint64_t kv_stride_token = kCacheVLLM.stride(-3); uint64_t kv_stride_head = kCacheVLLM.stride(-2); - mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, + mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, qScalePtr, reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT rcpOutScale, @@ -74,6 +87,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK reinterpret_cast(vCacheVLLM.data_ptr()), reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, kvCacheScale, + kvScalePtr, #if SPEC_DEC qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), reinterpret_cast(mask.data_ptr()), diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 4f4e8b0215..af8dda0345 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2175,11 +2175,6 @@ def trtllm_batch_decode_with_kv_cache( ) if backend == "xqa": - # TODO(Siyuan): support device scale factors, which was removed in #2033 - if not isinstance(bmm1_scale, float): - bmm1_scale = bmm1_scale.item() - if not isinstance(bmm2_scale, float): - bmm2_scale = bmm2_scale.item() # xqa backend doesn't support nvfp4 output if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): raise ValueError("xqa backend does not support nvfp4 output") @@ -2344,8 +2339,8 @@ def xqa_batch_decode_with_kv_cache( block_tables: torch.Tensor, seq_lens: torch.Tensor, max_seq_len: int, - bmm1_scale: float, - bmm2_scale: float, + bmm1_scale: Union[float, torch.Tensor] = 1.0, + bmm2_scale: Union[float, torch.Tensor] = 1.0, window_left: int = -1, out: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, @@ -2378,10 +2373,10 @@ def xqa_batch_decode_with_kv_cache( max_seq_len : int max sequence length for kv_cache - bmm1_scale : float + bmm1_scale : Union[float, torch.Tensor] fused scale for bmm1 input. - bmm2_scale : float + bmm2_scale : Union[float, torch.Tensor] fused scale for bmm2 input. 
window_left : int = -1 @@ -2428,13 +2423,6 @@ def xqa_batch_decode_with_kv_cache( sm_count = get_device_sm_count(query.device) - bmm1_scale = ( - bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale - ) - bmm2_scale = ( - bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale - ) - # Extract shape parameters based on layout if kv_layout == "NHD": # NHD: [num_pages, page_size, num_kv_heads, head_dim] @@ -2590,12 +2578,12 @@ def trtllm_batch_decode_with_kv_cache_mla( backend = ( "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" ) + if isinstance(bmm1_scale, torch.Tensor): + assert bmm1_scale.dtype == torch.float32 + bmm1_scale = bmm1_scale * log2e + if isinstance(bmm2_scale, torch.Tensor): + assert bmm2_scale.dtype == torch.float32 if backend == "xqa": - # TODO(Siyuan): support device scale factors, which was removed in #2033 - if not isinstance(bmm1_scale, float): - bmm1_scale = bmm1_scale.item() - if not isinstance(bmm2_scale, float): - bmm2_scale = bmm2_scale.item() if ( get_compute_capability(query.device)[0] != 12 or query.dtype != torch.float8_e4m3fn @@ -2662,12 +2650,6 @@ def trtllm_batch_decode_with_kv_cache_mla( "out", ) - if isinstance(bmm1_scale, torch.Tensor): - assert bmm1_scale.dtype == torch.float32 - bmm1_scale = bmm1_scale * log2e - if isinstance(bmm2_scale, torch.Tensor): - assert bmm2_scale.dtype == torch.float32 - run_func( out, None, # fp4 output not supported in wrapper api yet. @@ -2706,9 +2688,8 @@ def xqa_batch_decode_with_kv_cache_mla( seq_lens: torch.Tensor, max_seq_len: int, out: Optional[torch.Tensor] = None, - # TODO(Siyuan): support device scale factors, which was removed in #2033 - bmm1_scale: Optional[float] = 1.0, - bmm2_scale: Optional[float] = 1.0, + bmm1_scale: Union[float, torch.Tensor] = 1.0, + bmm2_scale: Union[float, torch.Tensor] = 1.0, sinks: Optional[List[torch.Tensor]] = None, enable_pdl: bool = None, ) -> torch.Tensor: @@ -2724,8 +2705,8 @@ def xqa_batch_decode_with_kv_cache_mla( seq_lens: query_len max_seq_len: max sequence length for kv_cache out: output tensor, if not provided, will be allocated internally - bmm1_scale: fused scale for mla bmm1 input. - bmm2_scale: fused scale for mla bmm2 input. + bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor. + bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor. sinks: additional value per head in the denominator of the softmax. 
Note: diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index dbf80e7b11..bbc4832aac 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -16,7 +16,7 @@ import functools from types import SimpleNamespace -from typing import Optional +from typing import Optional, Union import torch from .jit.xqa import gen_xqa_module, gen_xqa_module_mla @@ -59,7 +59,7 @@ def xqa( sm_count: int, num_kv_heads: int, sliding_win_size: int, - q_scale: float, + q_scale: Union[float, torch.Tensor], output: torch.Tensor, rcp_out_scale: float, q: torch.Tensor, @@ -70,7 +70,7 @@ def xqa( max_seq_len: int, seq_lens: torch.Tensor, batch_size: int, - kv_scale: float, + kv_scale: Union[float, torch.Tensor], semaphores: torch.Tensor, workspace_buffer: torch.Tensor, enable_pdl: bool, @@ -80,7 +80,8 @@ def xqa( sm_count, num_kv_heads, sliding_win_size, - q_scale, + 1.0 if isinstance(q_scale, torch.Tensor) else q_scale, + None if isinstance(q_scale, float) else q_scale, output, rcp_out_scale, q, @@ -91,7 +92,8 @@ def xqa( max_seq_len, seq_lens, batch_size, - kv_scale, + 1.0 if isinstance(kv_scale, torch.Tensor) else kv_scale, + None if isinstance(kv_scale, float) else kv_scale, semaphores, workspace_buffer, enable_pdl, @@ -105,7 +107,7 @@ def _fake_xqa( sm_count: int, num_kv_heads: int, sliding_win_size: int, - q_scale: float, + q_scale: Union[float, torch.Tensor], output: torch.Tensor, rcp_out_scale: float, q: torch.Tensor, @@ -116,9 +118,10 @@ def _fake_xqa( max_seq_len: int, seq_lens: torch.Tensor, batch_size: int, - kv_scale: float, + kv_scale: Union[float, torch.Tensor], semaphores: torch.Tensor, workspace_buffer: torch.Tensor, + enable_pdl: bool, ) -> None: pass @@ -139,8 +142,8 @@ def xqa( num_kv_heads: int, page_size: int, sinks: Optional[torch.Tensor] = None, - q_scale: float = 1.0, - kv_scale: float = 1.0, + q_scale: Union[float, torch.Tensor] = 1.0, + kv_scale: Union[float, torch.Tensor] = 1.0, sliding_win_size: int = 0, kv_layout: str = "NHD", sm_count: Optional[int] = None, @@ -188,9 +191,9 @@ def xqa( Attention sink values with shape ``[num_kv_heads, head_group_ratio]``. Data type should be torch.float32. If None, no attention sinks are used. - q_scale : float, default=1.0 + q_scale : Union[float, torch.Tensor], default=1.0 Scale factor for query tensor. - kv_scale : float, default=1.0 + kv_scale : Union[float, torch.Tensor], default=1.0 Scale factor for KV cache. sliding_win_size : int, default=0 Sliding window size for attention. If 0, no sliding window is used. 
@@ -319,7 +322,7 @@ def get_xqa_module_mla( ) def xqa_mla( sm_count: int, - q_scale: float, + q_scale: Union[float, torch.Tensor], output: torch.Tensor, q: torch.Tensor, k_cache: torch.Tensor, @@ -328,14 +331,15 @@ def xqa_mla( max_seq_len: int, seq_lens: torch.Tensor, batch_size: int, - kv_scale: float, + kv_scale: Union[float, torch.Tensor], semaphores: torch.Tensor, workspace_buffer: torch.Tensor, enable_pdl: bool, ) -> None: module.xqa_wrapper_mla( sm_count, - q_scale, + 1.0 if isinstance(q_scale, torch.Tensor) else q_scale, + None if isinstance(q_scale, float) else q_scale, output, q, k_cache, @@ -344,7 +348,8 @@ def xqa_mla( max_seq_len, seq_lens, batch_size, - kv_scale, + 1.0 if isinstance(kv_scale, torch.Tensor) else kv_scale, + None if isinstance(kv_scale, float) else kv_scale, semaphores, workspace_buffer, enable_pdl, @@ -355,7 +360,7 @@ def xqa_mla( ) def _fake_xqa_mla( sm_count: int, - q_scale: float, + q_scale: Union[float, torch.Tensor], output: torch.Tensor, q: torch.Tensor, k_cache: torch.Tensor, @@ -364,7 +369,7 @@ def _fake_xqa_mla( max_seq_len: int, seq_lens: torch.Tensor, batch_size: int, - kv_scale: float, + kv_scale: Union[float, torch.Tensor], semaphores: torch.Tensor, workspace_buffer: torch.Tensor, enable_pdl: bool, @@ -386,8 +391,8 @@ def xqa_mla( workspace_buffer: torch.Tensor, semaphores: torch.Tensor, page_size: int, - q_scale: float = 1.0, - kv_scale: float = 1.0, + q_scale: Union[float, torch.Tensor] = 1.0, + kv_scale: Union[float, torch.Tensor] = 1.0, sm_count: Optional[int] = None, enable_pdl: Optional[bool] = None, ) -> None: @@ -422,9 +427,9 @@ def xqa_mla( Data type should be torch.uint32. page_size : int Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128]. - q_scale : float, default=1.0 + q_scale : Union[float, torch.Tensor], default=1.0 Scale factor for query tensor. - kv_scale : float, default=1.0 + kv_scale : Union[float, torch.Tensor], default=1.0 Scale factor for KV cache. sm_count : Optional[int], default=None Number of streaming multiprocessors to use. 
diff --git a/tests/attention/test_xqa_batch_decode.py b/tests/attention/test_xqa_batch_decode.py index a360545041..542e3194bf 100644 --- a/tests/attention/test_xqa_batch_decode.py +++ b/tests/attention/test_xqa_batch_decode.py @@ -75,8 +75,8 @@ def create_kv_cache( ): # Create separate K and V caches with specified layout (NHD or HND) max_seq_len = torch.max(seq_lens).item() - num_tokens = max_seq_len * batch_size - num_pages = (num_tokens + page_size - 1) // page_size + num_pages_per_seq = (max_seq_len + page_size - 1) // page_size + num_pages = num_pages_per_seq * batch_size ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype] if kv_dtype != "fp8": assert kv_dtype == ref_kv_dtype, ( diff --git a/tests/attention/test_xqa_mla_batch_decode.py b/tests/attention/test_xqa_mla_batch_decode.py index 4d3abb52e1..aebf77da3d 100644 --- a/tests/attention/test_xqa_mla_batch_decode.py +++ b/tests/attention/test_xqa_mla_batch_decode.py @@ -49,8 +49,8 @@ def test_xqa_mla_batch_decode( device=device, ).to(dtype) - num_tokens = max_seq_len * batch_size - num_blocks = (num_tokens + page_size - 1) // page_size + num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size + num_blocks = num_blocks_per_seq * batch_size # Sequence lengths and block tables seq_lens = [torch.randint(1, max_seq_len, (1,)).item() for _ in range(batch_size)] From 049e8db923ed17f7bbbd89dfd8b642cb7532ed32 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 19 Nov 2025 23:16:13 -0800 Subject: [PATCH 075/130] hotfix: add 9.0a to README and installation doc (#2112) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description 9.0a was removed from installation documentation by accident, in some recent PRs. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes --- README.md | 12 ++++++++++-- docs/installation.rst | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 81b8583242..94eece5007 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,12 @@ Kernel Library for LLM Serving [![Build Status](https://ci.tlcpack.ai/job/flashinfer-ci/job/main/badge/icon)](https://ci.tlcpack.ai/job/flashinfer-ci/job/main/) [![Documentation](https://github.com/flashinfer-ai/flashinfer/actions/workflows/build-doc.yml/badge.svg)](https://github.com/flashinfer-ai/flashinfer/actions/workflows/build-doc.yml) - FlashInfer is a library and kernel generator for Large Language Models that provides high-performance implementation of LLM GPU kernels such as FlashAttention, SparseAttention, PageAttention, Sampling, and more. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios. Check our [v0.2 release blog](https://flashinfer.ai/2024/12/16/flashinfer-v02-release.html) for new features! The core features of FlashInfer include: + 1. 
**Efficient Sparse/Dense Attention Kernels**: Efficient single/batch attention for sparse(paged)/dense KV-storage on CUDA Cores and Tensor Cores (both FA2 & FA3) templates. The vector-sparse attention can achieve 90% of the bandwidth of dense kernels with same problem size. 2. **Load-Balanced Scheduling**: FlashInfer decouples `plan`/`run` stage of attention computation where we schedule the computation of variable-length inputs in `plan` stage to alleviate load-imbalance issue. 3. **Memory Efficiency**: FlashInfer offers [Cascade Attention](https://docs.flashinfer.ai/api/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper) for hierarchical KV-Cache, and implements Head-Query fusion for accelerating Grouped-Query Attention, and efficient kernels for low-precision attention and fused-RoPE attention for compressed KV-Cache. @@ -31,6 +31,7 @@ The core features of FlashInfer include: FlashInfer supports PyTorch, TVM and C++ (header-only) APIs, and can be easily integrated into existing projects. ## News + - [Mar 10, 2025] [Blog Post](https://flashinfer.ai/2025/03/10/sampling.html) Sorting-Free GPU Kernels for LLM Sampling, which explains the design of sampling kernels in FlashInfer. - [Mar 1, 2025] Checkout flashinfer's [intra-kernel profiler](https://github.com/flashinfer-ai/flashinfer/tree/main/profiler) for visualizing the timeline of each threadblock in GPU kernels. - [Dec 16, 2024] [Blog Post](https://flashinfer.ai/2024/12/16/flashinfer-v02-release.html) FlashInfer 0.2 - Efficient and Customizable Kernels for LLM Inference Serving @@ -51,11 +52,13 @@ pip install flashinfer-python ``` **Package Options:** + - **flashinfer-python**: Core package that compiles/downloads kernels on first use - **flashinfer-cubin**: Pre-compiled kernel binaries for all supported GPU architectures - **flashinfer-jit-cache**: Pre-built kernel cache for specific CUDA versions **For faster initialization and offline usage**, install the optional packages to have most kernels pre-compiled: + ```bash pip install flashinfer-python flashinfer-cubin # JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130) @@ -75,6 +78,7 @@ python -m pip install -v . ``` **For development**, install in editable mode: + ```bash python -m pip install --no-build-isolation -e . -v ``` @@ -82,6 +86,7 @@ python -m pip install --no-build-isolation -e . 
-v **Build optional packages:** `flashinfer-cubin`: + ```bash cd flashinfer-cubin python -m build --no-isolation --wheel @@ -89,8 +94,9 @@ python -m pip install dist/*.whl ``` `flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs): + ```bash -export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0a 12.0f" +export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl @@ -120,6 +126,7 @@ flashinfer show-config ``` This command displays: + - FlashInfer version and installed packages (flashinfer-python, flashinfer-cubin, flashinfer-jit-cache) - PyTorch and CUDA version information - Environment variables and artifact paths @@ -173,6 +180,7 @@ FlashInfer currently provides support for NVIDIA SM architectures 75 and higher ## Adoption We are thrilled to share that FlashInfer is being adopted by many cutting-edge projects, including but not limited to: + - [MLC-LLM](https://github.com/mlc-ai/mlc-llm) - [Punica](https://github.com/punica-ai/punica) - [SGLang](https://github.com/sgl-project/sglang) diff --git a/docs/installation.rst b/docs/installation.rst index eb2f1acf67..92bfec1651 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -92,7 +92,7 @@ You can follow the steps below to install FlashInfer from source code: .. code-block:: bash - export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0a 12.0f" + export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0a 12.0f" cd flashinfer-jit-cache python -m build --no-isolation --wheel python -m pip install dist/*.whl From 2628bebcf0b09dd80821c50f04dbbfa08ec32ca9 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 20 Nov 2025 10:33:40 -0800 Subject: [PATCH 076/130] ci/cd: add nvidia-ml-py to requirments of build-system of flashinfer-cubin (#2123) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description flashinfer-cubin package building failed because we flashinfer/utils.py relies on nvidia-ml-py which is not specified as part of build system requirements of the package. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Added a new build system dependency to support enhanced system functionality. โœ๏ธ Tip: You can customize this high-level summary in your review settings. 
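To illustrate the failure mode described above, here is a hedged sketch of the import chain: the `flashinfer-cubin` build backend pulls in helpers from `flashinfer/utils.py`, which imports `pynvml` (shipped on PyPI as `nvidia-ml-py`). The helper structure below is an assumption for illustration only, not the actual `build_backend` contents.

```python
# Illustrative sketch only -- not the real build_backend.py.
# flashinfer/utils.py does `import pynvml`, provided by the nvidia-ml-py
# distribution; if nvidia-ml-py is missing from [build-system].requires,
# this import fails inside the isolated build environment and the wheel
# build aborts.
try:
    import pynvml  # noqa: F401  (distributed as nvidia-ml-py on PyPI)
except ImportError as exc:
    raise RuntimeError(
        "nvidia-ml-py must be listed in [build-system].requires of "
        "flashinfer-cubin/pyproject.toml"
    ) from exc
```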
--- flashinfer-cubin/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer-cubin/pyproject.toml b/flashinfer-cubin/pyproject.toml index 866ff08db2..2bc526a4b3 100644 --- a/flashinfer-cubin/pyproject.toml +++ b/flashinfer-cubin/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0", "wheel", "requests", "filelock", "torch", "tqdm", "numpy", "apache-tvm-ffi>=0.1,<0.2"] +requires = ["setuptools>=61.0", "wheel", "requests", "filelock", "torch", "tqdm", "numpy", "apache-tvm-ffi>=0.1,<0.2", "nvidia-ml-py"] build-backend = "build_backend" backend-path = ["."] From 0aee7afd1d0db060776333f1274ebc0027d2b71e Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Thu, 20 Nov 2025 21:27:51 -0800 Subject: [PATCH 077/130] feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' (#1979) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Current PR: * Introduces an `auto` backend to `mm_fp4` that can be autotuned. **It replaces `cudnn` as the default.** * Implementation matches `bmm_fp8`'s auto backend support. * Allows `cudnn` backend to be autotuned. * Added unit test test cases for backend=auto Behavior of `auto` backend: * Examines CUDA version & cuDNN version and calls either `cutlass` or `cudnn` kernel backends. `trtllm` kernel is not considered due to a non-interchangeable interface with other backends. * `auto` backend therefore only supports inputs runnable by `cutlass` and/or `cudnn. * Non-autotuned behavior: * Constructs an ordered list of backends (cudnn, cutlass) or (cutlass, cudnn) where ordering is based on previous microbenchmark study results. * If CUDA 12 --> cutlass comes to front. * If CUDA 13 and cuDNN version < 9.15 --> cutlass comes front * If CUDA 13 and cuDNN version >= 9.15 --> cudnn comes front * If kernel is not available from a support check, it is removed from the list. * Autotune behavior: * If backend is explicitly provided --> Autotunes within the backend. Same as previous behavior, but now autotuning is supported for cudnn. * If `backend='auto'` --> Autotunes within and across backends (cudnn & cutlass) and chooses the best config of best backend. `trtllm` kernel is not considered * A lot of helper functions to `mm_fp4` were refactored to enable cross-backend autotuning. Refactoring was done to match cross-backend autotune-enabled `bmm_fp8` as a reference. 
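The backend-ordering heuristic in the bullets above can be summarized with the sketch below. The helper name and the way CUDA/cuDNN versions are obtained are assumptions for illustration; they do not mirror the refactored helpers in `flashinfer/gemm/gemm_base.py`.

```python
# Hypothetical helper illustrating the non-autotuned ordering for
# backend="auto" (names and version plumbing are assumptions).
def order_fp4_auto_backends(cuda_major: int, cudnn_version: tuple) -> list:
    candidates = ["cudnn", "cutlass"]  # trtllm is never considered by "auto"
    # CUDA 12, or CUDA 13 with cuDNN older than 9.15, moves cutlass to the
    # front; CUDA 13 with cuDNN >= 9.15 keeps cudnn first.
    if cuda_major < 13 or cudnn_version < (9, 15):
        candidates = ["cutlass", "cudnn"]
    # Backends that fail their support check are dropped before dispatch.
    return candidates
```

When autotuning is enabled this ordering is not used as a one-shot pick; all surviving candidates are profiled and the fastest backend/config pair wins, as described above.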
### Pytest outputs `pytest tests/gemm/test_mm_fp4.py` * SM100 (B200) CUDA 13 & cuDNN 9.15: `900 passed, 2532 skipped in 125.19s (0:02:05)` * SM100 (B200) CUDA 12 & cuDNN 9.15: `900 passed, 2532 skipped in 125.67s (0:02:05)` * SM120 (RTX 5090) CUDA 13 & cuDNN 9.15: `720 passed, 2712 skipped in 76.50s (0:01:16)` ### Example microbenchmark outputs: On SM100 (B200) CUDA 13 & cuDNN 9.15 ``` flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck [PERF] cudnn :: median time 0.018 ms; std 0.000 ms; achieved tflops 3797.932 TFLOPs/sec; achieved tb_per_sec 1.884 TB/sec [PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 3440.640 TFLOPs/sec; achieved tb_per_sec 1.707 TB/sec [PERF] trtllm :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec [PERF] auto :: median time 0.018 ms; std 0.000 ms; achieved tflops 3840.714 TFLOPs/sec; achieved tb_per_sec 1.905 TB/sec /flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [PERF] cudnn :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec [PERF] auto :: median time 0.021 ms; std 0.000 ms; achieved tflops 3237.753 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec median time 0.009 ms; std 0.000 ms; achieved tflops 938.356 TFLOPs/sec; achieved tb_per_sec 2.069 TB/sec ## Autotune /flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune 2025-11-11 23:43:23,715 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:25,789 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:43:25,790 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:26,251 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:43:26,251 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:26,327 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:43:26,327 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 
2025-11-11 23:43:26,335 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cudnn_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4129.171 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec [PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3513.845 TFLOPs/sec; achieved tb_per_sec 1.743 TB/sec [PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2613.338 TFLOPs/sec; achieved tb_per_sec 1.296 TB/sec [PERF] auto_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4128.768 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec /flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. 2025-11-11 23:43:37,942 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:43,116 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:43:43,116 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:43,124 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cudnn_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.154 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec [PERF] auto_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.692 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec ``` On SM100 (B200) CUDA 12 & cuDNN 9.15 ``` flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck [PERF] cudnn :: median time 0.023 ms; std 0.001 ms; achieved tflops 2975.898 TFLOPs/sec; achieved tb_per_sec 1.476 TB/sec [PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.423 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec [PERF] trtllm :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec [PERF] auto :: median time 0.020 ms; std 0.000 ms; achieved tflops 3371.229 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec (py312) root@84ef83abb1b5:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. 
[PERF] cudnn :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec [PERF] auto :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec ## Autotune /flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune 2025-11-11 23:42:43,378 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:45,451 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:42:45,451 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:45,910 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:42:45,910 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:45,986 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:42:45,986 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:45,993 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3190.355 TFLOPs/sec; achieved tb_per_sec 1.583 TB/sec [PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.330 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec [PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2621.440 TFLOPs/sec; achieved tb_per_sec 1.300 TB/sec [PERF] auto_autotune :: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.628 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. 2025-11-11 23:42:55,176 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:58,600 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:42:58,601 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 
2025-11-11 23:42:58,608 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec [PERF] auto_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec ``` On SM120 (RTX 5090) CUDA 13 & cuDNN 9.15 ``` /flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck [INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120 [PERF] cudnn :: median time 0.058 ms; std 0.000 ms; achieved tflops 1167.143 TFLOPs/sec; achieved tb_per_sec 0.579 TB/sec [PERF] cutlass :: median time 0.060 ms; std 0.000 ms; achieved tflops 1135.056 TFLOPs/sec; achieved tb_per_sec 0.563 TB/sec [PERF] auto :: median time 0.058 ms; std 0.000 ms; achieved tflops 1158.952 TFLOPs/sec; achieved tb_per_sec 0.575 TB/sec /flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120 [PERF] cudnn :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec [PERF] auto :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec ``` ## ๐Ÿ” Related Issues #1722 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * "auto" backend selection for FP4 ops to choose backend at runtime * cuDNN, CUTLASS and TRTLLM selectable as FP4 GEMM backends * CUDA/cuDNN version awareness to guide auto-backend heuristics * **Improvements** * Runtime capability checks replace static backend lists; unsupported backends are removed dynamically * Heuristic-driven auto-backend selection required for automatic mode * Expanded autotuning/warmup across backends and relaxed FP4 validation tolerance * **Tests** * Tests updated and added to exercise auto-backend scenarios and relaxed constraints โœ๏ธ Tip: You can customize this high-level summary in your review settings. 
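For context, the `mm_fp4` call pattern that the benchmark commands above exercise looks roughly like the sketch below. This is only an illustration, not part of the patch: it assumes the packed FP4 operands and their block scale factors (`a_fp4`, `b_fp4`, `a_scale`, `b_scale`, `alpha`) were already produced by the same nvfp4 quantization step that `benchmarks/routines/gemm.py` performs before calling the kernel.

```
from typing import Optional

import torch
import flashinfer


def fp4_gemm_auto(
    a_fp4: torch.Tensor,            # packed FP4 activations (uint8, two values per byte)
    b_fp4: torch.Tensor,            # packed FP4 weights, passed as a transposed view (mat2_fp4.T)
    a_scale: torch.Tensor,          # block scale factors for a
    b_scale: torch.Tensor,          # block scale factors for b
    alpha: Optional[torch.Tensor],  # global scale (nvfp4 only)
) -> torch.Tensor:
    # backend="auto" lets the heuristic pick between cudnn and cutlass at runtime;
    # with autotuning enabled, all suitable backends are tuned over instead.
    return flashinfer.gemm.mm_fp4(
        a=a_fp4,
        b=b_fp4,
        a_descale=a_scale,
        b_descale=b_scale,
        alpha=alpha,
        out_dtype=torch.bfloat16,
        block_size=16,            # nvfp4 uses 16-element blocks; mxfp4 uses 32
        use_8x4_sf_layout=False,  # 128x4 scale-factor layout
        backend="auto",
        use_nvfp4=True,
    )
```

Note that trtllm is never auto-selected because it expects a different input swizzling, and the heuristic prefers cuDNN only on CUDA 13 with cuDNN 9.15 or newer, falling back to CUTLASS otherwise.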
--- .../routines/flashinfer_benchmark_utils.py | 12 +- benchmarks/routines/gemm.py | 109 ++- flashinfer/gemm/gemm_base.py | 720 +++++++++++------- flashinfer/utils.py | 11 +- tests/gemm/test_mm_fp4.py | 50 +- tests/utils/test_decorators.py | 22 +- 6 files changed, 543 insertions(+), 381 deletions(-) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 520029f0ec..d5f363839a 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -235,17 +235,7 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cudnn", "cublas", "cutlass"], "12.0": ["cudnn", "cublas"], }, - "mm_fp4": { - "7.5": [], - "8.0": [], - "8.6": [], - "8.9": [], - "9.0": [], - "10.0": ["cudnn", "trtllm", "cutlass"], - "10.3": ["cudnn", "trtllm", "cutlass"], - "12.0": ["cudnn", "cutlass"], - "12.1": ["cudnn", "cutlass"], - }, + # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here # MOE "trtllm_fp4_block_scale_moe": { "7.5": [], diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 17336189d0..9f95f17fb4 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -131,7 +131,7 @@ def parse_gemm_args(line, parser): required=False, nargs="+", default=["cudnn"], - choices=["cudnn", "cublas", "trtllm", "cutlass"], + choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"], help="Kernel backends to test. Default: cudnn", ) parser.add_argument( @@ -790,61 +790,14 @@ def testMmFp4(args): run_refcheck = args.refcheck use_128x4_sf_layout = args.use_128x4_sf_layout use_nvfp4 = args.use_nvfp4 - autotune_supported_backends = ["cutlass", "trtllm"] + autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"] res = [] - backends = filter_backends_by_compute_capability(backends, args.routine, device) - res_dtype = dtype_str_to_torch_dtype(args.out_dtype) if res_dtype not in [torch.bfloat16, torch.float16]: raise ValueError( f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16." ) - ## Done parsing input arguments - - if "trtllm" in backends: - remove_trtllm = False - if res_dtype == torch.float16: - print("[INFO] trtllm backend does not support float16 output") - remove_trtllm = True - if remove_trtllm: - backends.remove("trtllm") - if not use_nvfp4: - print( - "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("trtllm") - if "cutlass" in backends: - remove_cutlass = False - if not use_128x4_sf_layout: - print("[INFO] cutlass backend does not support use_128x4_sf_layout=False") - remove_cutlass = True - if not use_nvfp4: - print( - "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("cutlass") - if remove_cutlass: - backends.remove("cutlass") - if "cudnn" in backends: - remove_cudnn = False - if not use_128x4_sf_layout: - print("[INFO] cudnn backend does not support use_128x4_sf_layout=False") - remove_cudnn = True - if remove_cudnn: - backends.remove("cudnn") - if getattr(args, "autotune", False): - backends_to_remove = [] - for cur_backend in backends: - if cur_backend not in autotune_supported_backends: - print(f"[INFO] {cur_backend} backend does not support autotune") - backends_to_remove.append(cur_backend) - for cur_backend in backends_to_remove: - backends.remove(cur_backend) - - if len(backends) == 0: - print("[ERROR] No backends to test. 
Exiting.") - return input = torch.randn([m, k], device=device, dtype=torch.bfloat16) mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16) @@ -886,11 +839,22 @@ def testMmFp4(args): print(f"[VVERBOSE] {mat2_fp4.dtype = }") alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None - # res = torch.empty([m, n], device="cuda", dtype=res_dtype) + # Completed preparing inputs. Now programmatically filter backends + block_size = 16 if use_nvfp4 else 32 + backends_to_remove = [] - def run_backend(backend): - if backend in ["cudnn", "trtllm", "cutlass"]: - return flashinfer.gemm.mm_fp4( + for backend in backends: + # Skip autotune check for now (handled separately below) + if ( + getattr(args, "autotune", False) + and backend not in autotune_supported_backends + ): + print(f"[INFO] {backend} backend does not support autotune") + backends_to_remove.append(backend) + continue + + try: + flashinfer.gemm.mm_fp4( a=input_fp4, b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, a_descale=input_inv_s, @@ -904,6 +868,34 @@ def run_backend(backend): backend=backend, use_nvfp4=use_nvfp4, ) + except Exception as e: + print( + f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}" + ) + backends_to_remove.append(backend) + + # Remove unsupported backends + for backend in backends_to_remove: + backends.remove(backend) + + if len(backends) == 0: + print("[ERROR] No backends passed validation. Exiting.") + return + + def run_backend(backend): + if backend in ["cudnn", "trtllm", "cutlass", "auto"]: + return flashinfer.gemm.mm_fp4( + a=input_fp4, + b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, + a_descale=input_inv_s, + b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T, + alpha=alpha, + out_dtype=res_dtype, + block_size=block_size, + use_8x4_sf_layout=not use_128x4_sf_layout, + backend=backend, + use_nvfp4=use_nvfp4, + ) else: raise ValueError(f"Unsupported backend: {backend}") @@ -917,12 +909,11 @@ def run_backend(backend): args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 ) for cur_backend in backends: - if cur_backend in autotune_supported_backends: - if args.verbose >= 1: - print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") - with autotune(True): - for _ in range(warmup_iters): - run_backend(cur_backend) + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") + with autotune(True): + for _ in range(warmup_iters): + run_backend(cur_backend) # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ac0fbab4a0..589c651aca 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -54,6 +54,7 @@ from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module from ..jit.gemm import gen_deepgemm_sm100_module +from ..jit.cpp_ext import get_cuda_version CUDNN_AVAILABLE = False @@ -406,87 +407,50 @@ def fp8_gemm_sm100( def _create_cutlass_fp4_gemm_module(module, op_name: str, tuner_name: str): """Helper function to create cutlass FP4 GEMM module.""" - class CutlassFp4GemmRunner(TunableRunner): - def __init__(self): - self._fp4_gemm_runner = module.fp4_gemm + def cutlass_fp4_gemm_runner(): + class CutlassFp4GemmRunner(TunableRunner): + def __init__(self): + self._fp4_gemm_runner = module.fp4_gemm - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: 
OptimizationProfile, - ) -> List[int]: - return list(range(module.fp4_gemm_tactic_num())) - - def forward( - self, - inputs: List[torch.Tensor], - tactic: int = -1, - do_preparation: bool = False, - **kwargs, - ): - a, b, a_descale, b_descale, alpha, out, workspace_buffer = inputs - module.fp4_gemm( - a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic - ) - return out - - @register_custom_op( - op_name, - mutates_args=(""), - ) - def cutlass_fp4_gemm( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, - workspace_buffer: torch.Tensor, - ): - tuner = AutoTuner.get() - - a_tensor_index = 0 - a_scale_tensor_index = 2 - out_tensor_index = 5 - - def pad_up(x, y): - return ((x + y - 1) // y) * y - - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (0,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - a_scale_tensor_index, - 0, - lambda shapes: pad_up(shapes[a_tensor_index][0], 128), - ), - ConstraintSpec( - out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] - ), - ), - ) - - fp4_runner = CutlassFp4GemmRunner() + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + return list(range(module.fp4_gemm_tactic_num())) - inputs = [a, b, a_descale, b_descale, alpha, out, workspace_buffer] - _, tactic = tuner.choose_one( - tuner_name, - [fp4_runner], - tuning_config, - inputs, - ) + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: + a_descale = a_descale.view(torch.uint8) + if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: + b_descale = b_descale.view(torch.uint8) + module.fp4_gemm( + a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer, tactic + ) + return out - fp4_runner(inputs=inputs, tactic=tactic) + return CutlassFp4GemmRunner() return SimpleNamespace( - cutlass_fp4_gemm=cutlass_fp4_gemm, + cutlass_fp4_gemm_runner=cutlass_fp4_gemm_runner, ) @@ -508,6 +472,17 @@ def get_gemm_sm120_module_cutlass_fp4(): ) +def get_cutlass_fp4_gemm_module( + sm_major: int, +): + if sm_major in [10, 11]: + return get_gemm_sm100_module_cutlass_fp4() + elif sm_major == 12: + return get_gemm_sm120_module_cutlass_fp4() + else: + raise ValueError(f"Unsupported SM major version: {sm_major}") + + @functools.cache def get_tgv_gemm_sm10x_module( dtype: torch.dtype = torch.bfloat16, use_sm_100f: bool = False @@ -1139,7 +1114,6 @@ def _check_cudnn_fp4_availability(): def _is_cublas_fp4_available_in_cudnn(): """Check if cuBLAS backend for FP4 GEMM is available in cuDNN.""" - _check_cudnn_availability() # Check cuDNN backend version for FP4 support (requires cudnn_version == 9.11.1 or cudnn_version >= 9.13) backend_version = cudnn.backend_version() @@ -1191,7 +1165,6 @@ def create_cudnn_execution_plans_fp4_gemm( alpha_is_not_none, use_nvfp4, ): - _check_cudnn_availability() stream = torch.cuda.current_stream(device) with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _): scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0 @@ -1292,7 +1265,9 @@ def build_plans_cudnn_fp4_gemm_graph( device, alpha, use_nvfp4, + tactic: int = -1, ): + # Graph 
should have been already cached, when we ran _cudnn_gemm_fp4_requirement graph = create_cudnn_execution_plans_fp4_gemm( a_shape, a_stride, @@ -1311,7 +1286,10 @@ def build_plans_cudnn_fp4_gemm_graph( ) graph.check_support() - graph.build_plans() + if tactic != -1: + graph.build_plan_at_index(tactic) + else: + graph.build_plans() return graph @@ -1324,6 +1302,7 @@ def execute_cudnn_gemm_fp4_graph( alpha, c_final, workspace_buffer, + tactic: int = -1, ): variant_pack = { UIDs.A_UID.value: a.view(get_native_fp4_dtype()), @@ -1343,7 +1322,12 @@ def execute_cudnn_gemm_fp4_graph( stream = torch.cuda.current_stream(a.device) - graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) + if tactic == -1: + graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) + else: + graph.execute_plan_at_index( + variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream) + ) @functools.cache @@ -1677,7 +1661,54 @@ def mm_fp8( return out -def _check_mm_fp4_problem_size( +def _get_cudnn_fp4_gemm_graph( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_nvfp4: bool = True, + tactic: int = -1, +): + # the fp4 cudnn graph will be shared for both mm and bmm, so + # here we need to get the 3d shape and stride including the + # batch dimension for both input and block scale tensors. + real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + expanded_a_descale_shape, expanded_a_descale_stride = ( + _expand_block_scale_tensor_shape(a_descale, batch) + ) + expanded_b_descale_shape, expanded_b_descale_stride = ( + _expand_block_scale_tensor_shape(b_descale, batch) + ) + + # build the fp4 cudnn graph + # Constructed graph is cached, via @functools.cache decorator. 
+ graph = build_plans_cudnn_fp4_gemm_graph( + real_a_shape, + real_a_stride, + real_b_shape, + real_b_stride, + expanded_a_descale_shape, + expanded_a_descale_stride, + expanded_b_descale_shape, + expanded_b_descale_stride, + cudnn.data_type.FP4_E2M1, + _torch_data_type_to_cudnn_data_type(out_dtype), + block_size, + a.device, + alpha is not None, + use_nvfp4, + tactic=tactic, + ) + return graph + + +def _cudnn_gemm_fp4( a: torch.Tensor, b: torch.Tensor, a_descale: torch.Tensor, @@ -1686,8 +1717,114 @@ def _check_mm_fp4_problem_size( out_dtype: torch.dtype = torch.bfloat16, out: Optional[torch.Tensor] = None, block_size: int = 16, - use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + use_nvfp4: bool = True, + workspace_buffer: torch.Tensor = None, + tactic: int = -1, +): + # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement + graph = _get_cudnn_fp4_gemm_graph( + a=a, + b=b, + a_descale=a_descale, + b_descale=b_descale, + alpha=alpha, + out_dtype=out_dtype, + out=out, + block_size=block_size, + use_nvfp4=use_nvfp4, + tactic=tactic, + ) + # execute the fp4 cudnn graph + execute_cudnn_gemm_fp4_graph( + graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic + ) + + +def _cudnn_gemm_fp4_runner(): + class CudnnFp4GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic + ( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ) = inputs + + # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement + graph = _get_cudnn_fp4_gemm_graph( + a=a, + b=b, + a_descale=a_descale, + b_descale=b_descale, + alpha=alpha, + out_dtype=out_dtype, + out=out, + block_size=block_size, + use_nvfp4=use_nvfp4, + tactic=-1, + ) + + num_plans = graph.get_execution_plan_count() + return list(range(num_plans)) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + ( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ) = inputs + _cudnn_gemm_fp4( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + tactic=tactic, + ) + + return CudnnFp4GemmRunner() + + +def _check_mm_fp4_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, + use_8x4_sf_layout: bool = False, # unused + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): # Generic checks @@ -1725,11 +1862,6 @@ def _check_mm_fp4_problem_size( f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." 
) - if backend != "trtllm" and use_8x4_sf_layout: - raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") - if backend != "cudnn" and not use_nvfp4: - raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.") - if use_nvfp4 and block_size != 16: raise ValueError("nvfp4 only supports block_size = 16.") if not use_nvfp4 and block_size != 32: @@ -1746,12 +1878,14 @@ def _cudnn_gemm_fp4_requirement( b_descale: torch.Tensor, alpha: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, # unused block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): + if use_8x4_sf_layout: + raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") if ( not use_nvfp4 and _match_sm_version(a.device, ["120"]) @@ -1774,7 +1908,8 @@ def _cudnn_gemm_fp4_requirement( _expand_block_scale_tensor_shape(b_descale, batch) ) - # build the fp4 cudnn graph + # build the fp4 cudnn graph. This graph will be cached & reused in mm_fp4() + # because the graph is constructed with @functools.cache decorator graph = create_cudnn_execution_plans_fp4_gemm( real_a_shape, real_a_stride, @@ -1798,18 +1933,20 @@ def _cudnn_gemm_fp4_requirement( @supported_compute_capability([100, 103]) def _trtllm_gemm_fp4_requirement( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: Optional[torch.Tensor] = None, + a: torch.Tensor, # unused + b: torch.Tensor, # unused + a_descale: torch.Tensor, # unused + b_descale: torch.Tensor, # unused + alpha: Optional[torch.Tensor] = None, # unused out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, - block_size: int = 16, - use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, # unused + use_8x4_sf_layout: bool = False, # unused + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): + if not use_nvfp4: + raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") if out_dtype != torch.bfloat16: raise ValueError( f"Unsupported output dtype: {out_dtype}. 
" @@ -1820,6 +1957,27 @@ def _trtllm_gemm_fp4_requirement( @supported_compute_capability([100, 103, 110, 120, 121]) def _cutlass_gemm_fp4_requirement( + a: torch.Tensor, # unused + b: torch.Tensor, # unused + a_descale: torch.Tensor, # unused + b_descale: torch.Tensor, # unused + alpha: Optional[torch.Tensor] = None, # unused + out_dtype: torch.dtype = torch.bfloat16, # unused + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, # unused + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused + use_nvfp4: bool = True, +): + if use_8x4_sf_layout: + raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") + if not use_nvfp4: + raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") + return True + + +def _heuristic_func_mm_fp4( + suitable_backends: List[str], a: torch.Tensor, b: torch.Tensor, a_descale: torch.Tensor, @@ -1829,19 +1987,42 @@ def _cutlass_gemm_fp4_requirement( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn", use_nvfp4: bool = True, ): - return True + r""" + Heuristic function for mm_fp4 backend selection. Routes to either cudnn or cutlass. + Note: trtllm is not considered in the backend selection because it requires a specific + input quantization (swizzling/shuffling) that differs from the preparation used + for cudnn and cutlass backends. + + Logic for which comes first: + - If cuda version is 12 - use cutlass. + - If cuda version is 13 and cudnn version is less than 9.15 - use cutlass. + - If cuda version is 13 and cudnn version is 9.15 or greater - use cudnn. + + """ + cuda_major = get_cuda_version().major + # If cuda version is 13 or greater: + # cudnn is more performant if cudnn version is 9.15 or greater. + if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500: + candidate_backends = ("cudnn", "cutlass") + # Otherwise, prioritize cutlass + else: + candidate_backends = ("cutlass", "cudnn") + + # Filter and return only supported backends + return [c for c in candidate_backends if c in suitable_backends] @backend_requirement( { - "cudnn": _cudnn_gemm_fp4_requirement, # Each backend has its own requirement function + "cudnn": _cudnn_gemm_fp4_requirement, "trtllm": _trtllm_gemm_fp4_requirement, "cutlass": _cutlass_gemm_fp4_requirement, }, - common_check=_check_mm_fp4_problem_size, # Shape checks common to all backends + common_check=_check_mm_fp4_problem_size, + heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) def mm_fp4( a: torch.Tensor, @@ -1853,7 +2034,7 @@ def mm_fp4( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ) -> torch.Tensor: r"""MM FP4 @@ -1887,8 +2068,8 @@ def mm_fp4( use_8x4_sf_layout: bool Whether to use 8x4 scale factor layout or 128x4 scale factor layout, defaults to False. - backend: Literal["cudnn", "trtllm", "cutlass"] - Backend to use, defaults to "cudnn". + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] + Backend to use, defaults to "auto", which automatically selects the best backend between cudnn and cutlass. use_nvfp4: bool Whether to use nvfp4 quantization or mxfp4 quantization, defaults to False. 
@@ -1930,70 +2111,78 @@ def mm_fp4( "mm_fp4_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) - if backend == "cudnn": - # the fp4 cudnn graph will be shared for both mm and bmm, so - # here we need to get the 3d shape and stride including the - # batch dimension for both input and block scale tensors. - real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) - real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) - batch = real_a_shape[0] - expanded_a_descale_shape, expanded_a_descale_stride = ( - _expand_block_scale_tensor_shape(a_descale, batch) - ) - expanded_b_descale_shape, expanded_b_descale_stride = ( - _expand_block_scale_tensor_shape(b_descale, batch) - ) + # Auto-select the best backend + if backend == "auto": + backends = mm_fp4.suitable_auto_backends + else: + backends = [backend] - # build the fp4 cudnn graph - graph = build_plans_cudnn_fp4_gemm_graph( - real_a_shape, - real_a_stride, - real_b_shape, - real_b_stride, - expanded_a_descale_shape, - expanded_a_descale_stride, - expanded_b_descale_shape, - expanded_b_descale_stride, - cudnn.data_type.FP4_E2M1, - _torch_data_type_to_cudnn_data_type(out_dtype), - block_size, - a.device, - alpha is not None, - use_nvfp4, - ) + # At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'. + # Lazy initialization of runners to avoid overhead of creating a new runner that will not be used + major, _ = get_compute_capability(a.device) - # execute the fp4 cudnn graph - execute_cudnn_gemm_fp4_graph( - graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer - ) - elif backend == "trtllm": - get_trtllm_fp4_gemm_module().trtllm_fp4_gemm( - a, - b.T, - a_descale, - b_descale.T, - alpha, - out, - use_8x4_sf_layout=use_8x4_sf_layout, - workspace_buffer=workspace_buffer, - ) - elif backend == "cutlass": - # cutlass require uint8 scale when a/b is fp4 packed uint8. - if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: - a_descale = a_descale.view(torch.uint8) - if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: - b_descale = b_descale.view(torch.uint8) - - # Dispatch to the correct module based on device architecture - major, _ = get_compute_capability(a.device) - if major == 12: - gemm_module = get_gemm_sm120_module_cutlass_fp4() - else: - gemm_module = get_gemm_sm100_module_cutlass_fp4() + backend_to_runner_factory = { + "cudnn": lambda: _cudnn_gemm_fp4_runner(), + "trtllm": lambda: get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner( + use_8x4_sf_layout + ), + "cutlass": lambda: get_cutlass_fp4_gemm_module(major).cutlass_fp4_gemm_runner(), + } + runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends] - gemm_module.cutlass_fp4_gemm( - a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer - ) + # Now we have a list of runners for desired & supported backends. 
+ tuner = AutoTuner.get() + + a_tensor_index = 0 + a_scale_tensor_index = 2 + out_tensor_index = 6 + + def pad_up(x, y): + return ((x + y - 1) // y) * y + + tuning_config = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (a_tensor_index,), + (0,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + a_scale_tensor_index, + 0, + lambda shapes: pad_up( + shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128 + ), + ), + ConstraintSpec( + out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] + ), + ), + ) + + inputs = [ + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ] + runner, tactic = tuner.choose_one( + "fp4_gemm", + runners, + tuning_config, + inputs, + ) + + runner(inputs=inputs, tactic=tactic) return out @@ -2355,139 +2544,82 @@ def get_trtllm_fp4_gemm_module(): op = mod.build_and_load() setup_cubin_loader(mod.get_library_path()) - class TrtllmFp4GemmRunner(TunableRunner): - def __init__(self, use_8x4_sf_layout: bool = True): - self._fp4_gemm_runner = op.trtllm_gemm - self._use_8x4_sf_layout = use_8x4_sf_layout + def trtllm_fp4_gemm_runner(use_8x4_sf_layout: bool = True): + class TrtllmFp4GemmRunner(TunableRunner): + def __init__(self, use_8x4_sf_layout: bool = True): + self._fp4_gemm_runner = op.trtllm_gemm + self._use_8x4_sf_layout = use_8x4_sf_layout - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: - a_tensor_index = 1 - b_tensor_index = 2 - - a = profile.get_opt_shapes()[a_tensor_index] - b = profile.get_opt_shapes()[b_tensor_index] - m = a[0] - n = b[0] - k = a[1] * 2 - ( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ) = inputs - type_e2m1 = 0 - type_bf16 = 2 - return list( - op.trtllm_gemm_tactics( - m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + a_tensor_index = 1 + b_tensor_index = 2 + + a = profile.get_opt_shapes()[a_tensor_index] + b = profile.get_opt_shapes()[b_tensor_index] + m = a[0] + n = b[0] + k = a[1] * 2 + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + type_e2m1 = 0 + type_bf16 = 2 + return list( + op.trtllm_gemm_tactics( + m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout + ) ) - ) - def forward( - self, - inputs: List[torch.Tensor], - tactic: int = -1, - do_preparation: bool = False, - **kwargs, - ): - ( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ) = inputs - op.trtllm_gemm( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - self._use_8x4_sf_layout, - tactic, - ) - return out - - @register_custom_op( - "flashinfer::trtllm_fp4_gemm", - mutates_args=(""), - ) - def trtllm_fp4_gemm( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, - use_8x4_sf_layout: bool, - workspace_buffer: torch.Tensor, - ): - tuner = AutoTuner.get() - - a_tensor_index = 1 - a_scale_tensor_index = 3 - out_tensor_index = 6 - - def pad_up(x, y): - return ((x + y - 1) // y) * y - - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (0,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - a_scale_tensor_index, - 0, - 
lambda shapes: pad_up( - shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128 - ), - ), - ConstraintSpec( - out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] - ), - ), - ) - - fp4_runner = TrtllmFp4GemmRunner(use_8x4_sf_layout) - - inputs = [ - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ] - _, tactic = tuner.choose_one( - "trtllm_fp4_gemm_8x4" if use_8x4_sf_layout else "trtllm_fp4_gemm_128x4", - [fp4_runner], - tuning_config, - inputs, - ) + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + self._fp4_gemm_runner( + workspace_buffer, + a, + b.T, + a_descale, + b_descale.T, + alpha, + out, + self._use_8x4_sf_layout, + tactic, + ) + return out - fp4_runner(inputs=inputs, tactic=tactic) + return TrtllmFp4GemmRunner(use_8x4_sf_layout) # Register the module return SimpleNamespace( - trtllm_fp4_gemm=trtllm_fp4_gemm, + trtllm_fp4_gemm_runner=trtllm_fp4_gemm_runner, ) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 76689bab84..e323125efa 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -921,6 +921,13 @@ def backend_requirement( backends. Should accept the same arguments as the decorated function and return True if requirements are met, False otherwise. In the case where the kernel function does not have any specific backends, this can be decorated with @supported_compute_capability to specify the function's supported compute capabilities. + heuristic_func : callable, optional + A function that performs heuristic backend selection when backend is "auto". + Must be provided if backend is "auto". Does not do anything if backend is not "auto". + Should accept the same arguments as the decorated function. + Should return an ordered list of runnable backends with the most preferred backend first. + When decorated function is not autotuned, the first backend in the heuristic list will be run. + When decorated function is autotuned, the backends in the heuristic list will be autotuned over to find the best backend. 
Returns ------- @@ -1076,8 +1083,8 @@ def suitable_auto_backends(cc, *args, **kwargs): except ValueError: continue # If a heuristic function is provided, filter the suitable backends based on the heuristic function - if heuristic_func is not None: - suitable_backends = heuristic_func(suitable_backends, *args, **kwargs) + assert heuristic_func is not None, "Heuristic function must be provided" + suitable_backends = heuristic_func(suitable_backends, *args, **kwargs) if not suitable_backends: return False wrapper.suitable_auto_backends = suitable_backends diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index 9d7a7abbbd..cc85f6126a 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -12,16 +12,7 @@ from flashinfer.gemm.gemm_base import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR -# TODO: Consdier splitting this function up for the various backends -@pytest.mark.parametrize("m", [1, 48, 128, 256, 512]) -@pytest.mark.parametrize("n", [128, 256, 512]) -@pytest.mark.parametrize("k", [128, 256, 512]) -@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) -@pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) -@pytest.mark.parametrize("auto_tuning", [False, True]) -@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) -def test_mm_fp4( +def _test_mm_fp4( m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type ): use_nvfp4 = fp4_type == "nvfp4" @@ -40,10 +31,8 @@ def test_mm_fp4( pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.") if not use_128x4_sf_layout and backend != "trtllm": pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False") - if auto_tuning and backend == "cudnn": - pytest.skip("Skipping test for cudnn fp4 with auto_tuning=True") - if not use_nvfp4 and backend != "cudnn": - pytest.skip("mx_fp4 is only supported for cudnn backend") + if not use_nvfp4 and backend not in ["cudnn", "auto"]: + pytest.skip("mx_fp4 is only supported for cudnn and auto backends") input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) @@ -105,5 +94,38 @@ def test_mm_fp4( pytest.fail(str(e)) +# TODO: Consdier splitting this function up for the various backends +@pytest.mark.parametrize("m", [1, 48, 128, 256, 512]) +@pytest.mark.parametrize("n", [128, 256, 512]) +@pytest.mark.parametrize("k", [128, 256, 512]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) +@pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) +@pytest.mark.parametrize("auto_tuning", [False, True]) +@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) +def test_mm_fp4( + m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type +): + # Non-auto backends + _test_mm_fp4( + m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type + ) + + +# Split tests for checking auto functionality +@pytest.mark.parametrize("m", [1, 48, 256, 512]) +@pytest.mark.parametrize("n", [256, 512]) +@pytest.mark.parametrize("k", [256, 512]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("use_128x4_sf_layout", [True]) +@pytest.mark.parametrize("auto_tuning", [False, True]) +@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) +def test_mm_fp4_backend_auto( + m, n, k, res_dtype, 
use_128x4_sf_layout, auto_tuning, fp4_type +): + # Some test cases for auto backend. + _test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index ebbda781fb..f8659c2e44 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -344,7 +344,27 @@ def _cutlass_check(x, backend): def _cudnn_check(x, backend): return x.shape[0] > 5 - @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check}) + # When using an auto backend, some heuristic function must exist + def _heuristic_func(suitable_backends, x, backend): + candidate_backends = None + if x.shape[0] > 5: + candidate_backends = ["cudnn", "cutlass"] + else: + candidate_backends = ["cutlass", "cudnn"] + + heuristic_backends = [] + for backend in candidate_backends: + if backend in suitable_backends: + heuristic_backends.append(backend) + return heuristic_backends + + @backend_requirement( + backend_checks={ + "cutlass": _cutlass_check, + "cudnn": _cudnn_check, + }, + heuristic_func=_heuristic_func, + ) def my_kernel(x, backend="auto"): backends = my_kernel.suitable_auto_backends if x.shape[0] > 5: From 7128c7bc9e08e50e488f296ebd9aa1d6c5600d13 Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Fri, 21 Nov 2025 14:56:54 -0800 Subject: [PATCH 078/130] fix: Fix bench_mm_fp8.py (#2129) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description `bench_mm_fp8.py` was not functioning because `res` was being provided as a fourth positional argument when it should be given as out=res ``` def mm_fp8( a: torch.Tensor, b: torch.Tensor, alpha: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, out: Optional[torch.Tensor] = None, backend: Literal["trtllm_low_latency"] = "trtllm_low_latency", ): ``` Output after fix: ``` flashinfer$ python3 benchmarks/bench_mm_fp8.py 2025-11-21 09:38:10,084 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-21 09:38:10,328 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends mm_fp8 m=1 n=2560 k=16384 in_dtype=torch.float8_e4m3fn out_dtype=torch.bfloat16: 6.36 TFLOPs/s over 0.013199 ms, 3.18 TB/s 2025-11-21 09:38:10,551 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-21 09:38:10,573 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends mm_fp8 m=1 n=2560 k=32768 in_dtype=torch.float8_e4m3fn out_dtype=torch.bfloat16: 7.28 TFLOPs/s over 0.023040 ms, 3.64 TB/s 2025-11-21 09:38:10,671 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-21 09:38:10,692 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends mm_fp8 m=1 n=5120 k=16384 in_dtype=torch.float8_e4m3fn out_dtype=torch.bfloat16: 8.31 TFLOPs/s over 0.020191 ms, 4.16 TB/s 2025-11-21 09:38:10,789 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-21 09:38:10,813 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends mm_fp8 m=1 n=5120 k=32768 in_dtype=torch.float8_e4m3fn out_dtype=torch.bfloat16: 9.40 TFLOPs/s over 0.035696 ms, 4.70 TB/s 2025-11-21 09:38:10,918 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 
2025-11-21 09:38:10,941 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends mm_fp8 m=1 n=8192 k=16384 in_dtype=torch.float8_e4m3fn out_dtype=torch.bfloat16: 9.16 TFLOPs/s over 0.029312 ms, 4.58 TB/s 2025-11-21 09:38:11,045 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-21 09:38:11,072 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends mm_fp8 m=1 n=8192 k=32768 in_dtype=torch.float8_e4m3fn out_dtype=torch.bfloat16: 10.14 TFLOPs/s over 0.052959 ms, 5.07 TB/s ... ``` Also changed measurement methodology slightly to use cupti. Previous methodology inflated performance numbers due to not flushing L2 cache or using a rotating buffer to start with a cold cash. Benchmark should produce much accurate performance numbers due to L2 flush with `enable_cupti=True` ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Adjusted benchmark timing settings to shorten warm-up and measurement durations for faster test runs. * Enabled CUPTI profiling for more detailed GPU performance metrics in FP8 matrix-multiplication benchmarks. * Made non-functional parameter/argument updates and clarifying comments; no changes to core computation logic. โœ๏ธ Tip: You can customize this high-level summary in your review settings. --- benchmarks/bench_mm_fp8.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_mm_fp8.py b/benchmarks/bench_mm_fp8.py index a4df76ebd9..7661d5a57e 100644 --- a/benchmarks/bench_mm_fp8.py +++ b/benchmarks/bench_mm_fp8.py @@ -67,11 +67,12 @@ def bench_mm_fp8(m, n, k, in_dtype, out_dtype): input_fp8, prepared_weights, global_scale, - res, + out=res, ), - dry_run_time_ms=500, - repeat_time_ms=2500, + dry_run_time_ms=25, + repeat_time_ms=100, # 100ms should be enough for low latency kernels that run within 100 usec use_cuda_graph=True, + enable_cupti=True, ) ms = np.median(measurements) tflops_per_second = 2 * m * n * k * 1e-9 / ms From 5acb57bd8b3cbb23d010ed8d306c19c4c6c5985d Mon Sep 17 00:00:00 2001 From: "Brian K. Ryu" Date: Fri, 21 Nov 2025 23:24:14 -0800 Subject: [PATCH 079/130] feat: Enable API Logging for Better Debugging POC (#2108) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description tl; dr: Current PR adds a logging system for input/output tracking to aid debugging FlashInfer APIs via a `@flashinfer_api` decorator. **This PR does not label `@flashinfer_api` to every FlashInfer API -- many operations are missing labels. Further labeling is left for subsequent work.** This PR introduces a production-ready API logging infrastructure that tracks function calls, arguments, and return values via a simple one-line decorator. 
Any function can be decorated with the decorator to track the input/output values in the API logger. Key Features: * Logging level controlled by `FLASHINFER_LOGLEVEL` * Log destination set by `FLASHINFER_LOGDEST`; defaults to `stdout` * Zero overhead when disabled (level 0 returns original function) as seen from `benchmarks/bench_logging_overhead.py` Example usage ``` export FLASHINFER_LOGLEVEL=1 export FLASHINFER_LOGDEST="./flashinfer_api.log" python3 benchmarks/flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 ``` produces log ``` ================================================================================ [2025-11-20 17:51:18] FlashInfer API Logging - System Information ================================================================================ FlashInfer version: 0.5.2 CUDA toolkit version: 13.0 cuDNN version: 91600 Number of GPUs: 1 GPU 0: NVIDIA B200 Compute capability: 10.0 (SM100) PyTorch version: 2.9.0+cu130 ================================================================================ [2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__ [2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.plan [2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__ [2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.plan [2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__ [2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.plan [2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.run [2025-11-20 17:51:19] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.run ... ``` `export FLASHINFER_LOGLEVEL=3` produces: ``` (System Info same as above) ================================================================================ [2025-11-20 17:51:58] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__ -------------------------------------------------------------------------------- Positional input arguments: arg[0]: arg[1]: Tensor( shape=(134217728,) stride=(1,) dtype=torch.int8 device=cuda:0 requires_grad=False is_contiguous=True ) arg[2]: 'HND' Keyword input arguments: use_cuda_graph= True use_tensor_cores= False paged_kv_indptr_buffer= Tensor( shape=(2,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True ) paged_kv_indices_buffer= Tensor( shape=(6,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True ) paged_kv_last_page_len_buffer= Tensor( shape=(1,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True ) backend= 'fa2' Default parameters (not explicitly provided): jit_args= [DEFAULT] None Output value: None ================================================================================ ... 
``` `export FLASHINFER_LOGLEVEL=5` produces: ``` (System Info same as above) ================================================================================ [2025-11-20 17:52:23] FlashInfer API Call: BatchDecodeWithPagedKVCacheWrapper.__init__ -------------------------------------------------------------------------------- Positional input arguments: arg[0]: arg[1]: Tensor( shape=(134217728,) stride=(1,) dtype=torch.int8 device=cuda:0 requires_grad=False is_contiguous=True min=0 max=0 mean=0.000000 ) arg[2]: 'HND' Keyword input arguments: use_cuda_graph= True use_tensor_cores= False paged_kv_indptr_buffer= Tensor( shape=(2,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True min=0 max=6 mean=3.000000 ) paged_kv_indices_buffer= Tensor( shape=(6,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True min=0 max=5 mean=2.500000 ) paged_kv_last_page_len_buffer= Tensor( shape=(1,) stride=(1,) dtype=torch.int32 device=cuda:0 requires_grad=False is_contiguous=True min=4 max=4 mean=4.000000 ) backend= 'fa2' Default parameters (not explicitly provided): jit_args= [DEFAULT] None Output value: None ================================================================================ ... ``` ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit ## Release Notes * **New Features** * Added API logging feature configurable via environment variables (FLASHINFER_LOGLEVEL for level control, FLASHINFER_LOGDEST for destination) * Supports five verbosity levels with function names, inputs, outputs, metadata, and tensor statistics * Zero-overhead operation when disabled * **Tests** * Added comprehensive logging test suite * **Documentation** * Added logging configuration and usage documentation โœ๏ธ Tip: You can customize this high-level summary in your review settings. 
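As a quick reference, wiring the decorator into user code is a one-liner. The sketch below is illustrative only: `my_matmul` and the toy tensors are made up, while `flashinfer_api`, the two environment variables, and the `%i` process-id substitution are the pieces this PR introduces (mirroring `benchmarks/bench_logging_overhead.py`).

```
import os

# Configure logging before importing flashinfer, matching the benchmark script.
os.environ["FLASHINFER_LOGLEVEL"] = "3"                     # names + inputs/outputs + metadata
os.environ["FLASHINFER_LOGDEST"] = "flashinfer_api_%i.log"  # %i is replaced by the process id

import torch
from flashinfer.api_logging import flashinfer_api


@flashinfer_api
def my_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Each call is logged with its arguments and return value at level >= 3.
    return torch.matmul(a, b)


x = torch.randn(4, 8)
y = torch.randn(8, 2)
my_matmul(x, y)  # produces an entry in flashinfer_api_<pid>.log
```

Leaving `FLASHINFER_LOGLEVEL` unset (level 0) keeps the zero-overhead path, since the decorator then returns the original function unchanged.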
--- README.md | 14 + benchmarks/bench_logging_overhead.py | 333 +++++++++++++++ docs/index.rst | 1 + docs/logging.rst | 118 ++++++ flashinfer/api_logging.py | 565 +++++++++++++++++++++++++ flashinfer/cudnn/decode.py | 2 + flashinfer/cudnn/prefill.py | 2 + flashinfer/decode.py | 10 + flashinfer/fused_moe/core.py | 7 + flashinfer/gemm/gemm_base.py | 12 + flashinfer/mla.py | 4 + flashinfer/prefill.py | 11 + tests/utils/test_logging.py | 588 +++++++++++++++++++++++++++ 13 files changed, 1667 insertions(+) create mode 100644 benchmarks/bench_logging_overhead.py create mode 100644 docs/logging.rst create mode 100644 flashinfer/api_logging.py create mode 100644 tests/utils/test_logging.py diff --git a/README.md b/README.md index 94eece5007..cd5c7e1e58 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,20 @@ o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill att Check out [documentation](https://docs.flashinfer.ai/) for usage of batch decode/append/prefill kernels and shared-prefix cascading kernels. +## API Logging + +FlashInfer provides comprehensive API logging for debugging. Enable it using environment variables: + +```bash +# Enable logging (levels: 0=off (default), 1=basic, 3=detailed, 5=statistics) +export FLASHINFER_LOGLEVEL=3 + +# Set log destination (stdout (default), stderr, or file path) +export FLASHINFER_LOGDEST=stdout +``` + +For detailed information about logging levels, configuration, and advanced features, see [LOGGING.md](LOGGING.md). + ## Custom Attention Variants Starting from FlashInfer v0.2, users can customize their own attention variants with additional parameters. For more details, refer to our [JIT examples](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_jit_example.py). diff --git a/benchmarks/bench_logging_overhead.py b/benchmarks/bench_logging_overhead.py new file mode 100644 index 0000000000..e67edcfa45 --- /dev/null +++ b/benchmarks/bench_logging_overhead.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +""" +Benchmark script to measure the overhead of API logging at different levels. + +This script creates decorated and undecorated versions of a test function +(torch.matmul) and compares their performance to accurately measure logging overhead. + +Usage: + # Set the logging level before running + export FLASHINFER_LOGLEVEL=3 + python bench_logging_overhead.py + + # Or run with different levels + FLASHINFER_LOGLEVEL=0 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=1 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=3 python bench_logging_overhead.py + FLASHINFER_LOGLEVEL=5 python bench_logging_overhead.py + + # Or use the helper script to run all levels + bash benchmark_all_levels.sh +""" + +import os +import sys +import time +import torch +import numpy as np +from typing import List, Tuple + +# Get logging level BEFORE importing flashinfer +LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) +LOG_DEST = os.environ.get("FLASHINFER_LOGDEST", "/tmp/flashinfer_benchmark_log.txt") + +# Import the decorator +from flashinfer.api_logging import flashinfer_api + + +# Create two versions of a test function: +# 1. Undecorated (baseline) +# 2. 
Decorated (with logging) +def test_matmul_undecorated(A, B): + return torch.matmul(A, B) + + +@flashinfer_api +def test_matmul_decorated(A, B): + return torch.matmul(A, B) + + +class BenchmarkResults: + """Store and display benchmark results.""" + + def __init__(self): + self.undecorated_times = [] + self.decorated_times = [] + + def set_undecorated(self, times: List[float]): + """Set benchmark results for undecorated function.""" + self.undecorated_times = times + + def set_decorated(self, times: List[float]): + """Set benchmark results for decorated function.""" + self.decorated_times = times + + def print_summary(self, logging_level: int): + """Print a summary of benchmark results.""" + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + + undecorated_mean = np.mean(self.undecorated_times) + undecorated_std = np.std(self.undecorated_times) + + decorated_mean = np.mean(self.decorated_times) + decorated_std = np.std(self.decorated_times) + + overhead_abs = (decorated_mean - undecorated_mean) * 1000 # ms + overhead_pct = ( + ((decorated_mean - undecorated_mean) / undecorated_mean * 100) + if undecorated_mean > 0 + else 0 + ) + + print( + f"\n{'Version':<20} {'Mean (ms)':<12} {'Std (ms)':<12} {'Median (ms)':<12}" + ) + print("-" * 80) + print( + f"{'Undecorated':<20} {undecorated_mean * 1000:<12.4f} {undecorated_std * 1000:<12.4f} {np.median(self.undecorated_times) * 1000:<12.4f}" + ) + print( + f"{'Decorated':<20} {decorated_mean * 1000:<12.4f} {decorated_std * 1000:<12.4f} {np.median(self.decorated_times) * 1000:<12.4f}" + ) + + print("\n" + "=" * 80) + print("OVERHEAD ANALYSIS") + print("=" * 80) + print(f"\nLogging Level: {logging_level}") + print(f"Absolute overhead: {overhead_abs:.4f} ms") + print(f"Relative overhead: {overhead_pct:.2f}%") + + print("\n" + "=" * 80) + print("DETAILED STATISTICS") + print("=" * 80) + + print("\nUndecorated (baseline):") + print(f" Mean: {undecorated_mean * 1000:.4f} ms") + print(f" Median: {np.median(self.undecorated_times) * 1000:.4f} ms") + print(f" Std: {undecorated_std * 1000:.4f} ms") + print(f" Min: {np.min(self.undecorated_times) * 1000:.4f} ms") + print(f" Max: {np.max(self.undecorated_times) * 1000:.4f} ms") + + print("\nDecorated (with logging):") + print(f" Mean: {decorated_mean * 1000:.4f} ms") + print(f" Median: {np.median(self.decorated_times) * 1000:.4f} ms") + print(f" Std: {decorated_std * 1000:.4f} ms") + print(f" Min: {np.min(self.decorated_times) * 1000:.4f} ms") + print(f" Max: {np.max(self.decorated_times) * 1000:.4f} ms") + + +def setup_test_inputs( + batch_size: int = 32, + m: int = 512, + n: int = 512, + k: int = 512, + device: str = "cuda:0", +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Set up test inputs for matmul. + + Parameters + ---------- + batch_size : int + Batch size for the matrix multiplication + m, n, k : int + Matrix dimensions + device : str + Device to use + + Returns + ------- + A, B : torch.Tensor + Input tensors for matrix multiplication + """ + # Create random tensors + A = torch.randn(batch_size, m, k, dtype=torch.float16, device=device) + B = torch.randn(batch_size, k, n, dtype=torch.float16, device=device) + + return A, B + + +def warmup(func, A, B, num_warmup: int = 10): + """Warmup the GPU and JIT compilation.""" + for _ in range(num_warmup): + _ = func(A, B) + torch.cuda.synchronize() + + +def benchmark_function( + func, func_name: str, A, B, num_iterations: int = 100 +) -> List[float]: + """ + Benchmark a specific function. 
+ + Parameters + ---------- + func : callable + Function to benchmark + func_name : str + Name of the function (for display) + A, B : torch.Tensor + Input tensors for matrix multiplication + num_iterations : int + Number of iterations to run + + Returns + ------- + List[float] + List of execution times in seconds + """ + print(f"\nBenchmarking: {func_name}") + print(f" Running {num_iterations} iterations...") + + times = [] + + for _ in range(num_iterations): + # Synchronize before timing + torch.cuda.synchronize() + + # Time the execution + start = time.perf_counter() + _ = func(A, B) + torch.cuda.synchronize() + end = time.perf_counter() + + elapsed = end - start + times.append(elapsed) + + print(f" Complete. Mean time: {np.mean(times) * 1000:.4f} ms") + + return times + + +def main(): + """Main benchmark function.""" + print("=" * 80) + print("FlashInfer API Logging Overhead Benchmark") + print("=" * 80) + + # Display logging configuration + print("\nLogging Configuration:") + print(f" FLASHINFER_LOGLEVEL = {LOGGING_LEVEL}") + print(f" FLASHINFER_LOGDEST = {LOG_DEST}") + + # Get level name + level_names = { + 0: "No logging (zero-overhead)", + 1: "Function name only", + 3: "Name + inputs/outputs + metadata", + 5: "Name + inputs/outputs + metadata + statistics", + } + print(f" Level description: {level_names.get(LOGGING_LEVEL, 'Unknown')}") + + # Check if CUDA is available + if not torch.cuda.is_available(): + print("\nError: CUDA is not available. This benchmark requires a CUDA device.") + exit(1) + + device = "cuda:0" + print(f"\nDevice: {device}") + print(f"Device Name: {torch.cuda.get_device_name(device)}") + + # Setup test inputs + print("\nSetting up test inputs...") + batch_size = 32 + m, n, k = 128, 128, 128 + print(f" Batch size: {batch_size}") + print(f" Matrix dimensions: [{batch_size}, {m}, {k}] @ [{batch_size}, {k}, {n}]") + + A, B = setup_test_inputs(batch_size, m, n, k, device) + + # Benchmark parameters + num_iterations = 100 + print("\nBenchmark parameters:") + print(f" Iterations: {num_iterations}") + print(" Warmup iterations: 10") + + # Clear log file before starting + if os.path.exists(LOG_DEST): + os.remove(LOG_DEST) + + print("\n" + "=" * 80) + print("WARMUP PHASE") + print("=" * 80) + + # Warmup undecorated version + print("\nWarming up undecorated version...") + warmup(test_matmul_undecorated, A, B, num_warmup=10) + print(" Complete.") + + # Warmup decorated version + print("\nWarming up decorated version...") + warmup(test_matmul_decorated, A, B, num_warmup=10) + print(" Complete.") + + print("\n" + "=" * 80) + print("BENCHMARK PHASE") + print("=" * 80) + + # Store results + results = BenchmarkResults() + + # Benchmark undecorated version + undecorated_times = benchmark_function( + test_matmul_undecorated, "Undecorated (baseline)", A, B, num_iterations + ) + results.set_undecorated(undecorated_times) + + # Benchmark decorated version + decorated_times = benchmark_function( + test_matmul_decorated, + f"Decorated (logging level {LOGGING_LEVEL})", + A, + B, + num_iterations, + ) + results.set_decorated(decorated_times) + + # Print summary + results.print_summary(LOGGING_LEVEL) + + # Check log file size + if LOGGING_LEVEL > 0 and os.path.exists(LOG_DEST): + log_size = os.path.getsize(LOG_DEST) + print("\n" + "=" * 80) + print("LOG FILE INFO") + print("=" * 80) + print(f"Log file: {LOG_DEST}") + print(f"Log size: {log_size / 1024:.2f} KB ({log_size} bytes)") + print(f"Iterations logged: {num_iterations}") + print(f"Bytes per iteration: {log_size / 
num_iterations:.2f}") + + # Cleanup option + cleanup_log = os.environ.get("CLEANUP_LOG", "true").lower() == "true" + if cleanup_log: + os.remove(LOG_DEST) + print("\n Log file removed (set CLEANUP_LOG=false to keep it)") + else: + print(f"\n Log file preserved at {LOG_DEST}") + + print("\n" + "=" * 80) + print("RECOMMENDATIONS") + print("=" * 80) + print("\nTo benchmark other levels, run:") + for level in [0, 1, 3, 5]: + if level != LOGGING_LEVEL: + print(f" FLASHINFER_LOGLEVEL={level} python {sys.argv[0]}") + + print("\n" + "=" * 80) + print("Benchmark complete!") + print("=" * 80) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\n\nBenchmark interrupted by user.") + except Exception as e: + print(f"\n\nError during benchmark: {e}") + import traceback + + traceback.print_exc() diff --git a/docs/index.rst b/docs/index.rst index 6a5a9c6a19..f4e61d26c4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,6 +15,7 @@ FlashInfer is a library and kernel generator for Large Language Models that prov :caption: Get Started installation + logging .. toctree:: :maxdepth: 2 diff --git a/docs/logging.rst b/docs/logging.rst new file mode 100644 index 0000000000..c3c2c83d8f --- /dev/null +++ b/docs/logging.rst @@ -0,0 +1,118 @@ +.. _logging: + +Logging +======= + +FlashInfer provides a logging feature to help debug issues and reproduce crashes. This document describes all available logging levels and their features. + +Quick Start +----------- + +Enable logging using two environment variables: + +.. code-block:: bash + + # Set logging level (0-5) + export FLASHINFER_LOGLEVEL=3 + + # Set log destination (default is stdout) + export FLASHINFER_LOGDEST=stdout # or stderr, or a file path like "flashinfer.log" + +Logging Levels +-------------- + +.. list-table:: + :header-rows: 1 + :widths: 10 20 35 25 + + * - Level + - Name + - Features + - Use Case + * - **0** + - Disabled (Default) + - No logging (zero overhead) + - Production + * - **1** + - Function Names + - Function names only + - Basic tracing + * - **3** + - Inputs/Outputs + - Function names + arguments + outputs with metadata + - Standard debugging + * - **5** + - Statistics + - Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) + - Numerical analysis + +Environment Variables +--------------------- + +Main Configuration +^^^^^^^^^^^^^^^^^^ + +.. list-table:: + :header-rows: 1 + :widths: 30 15 15 40 + + * - Variable + - Type + - Default + - Description + * - ``FLASHINFER_LOGLEVEL`` + - int + - 0 + - Logging level (0, 1, 3, 5) + * - ``FLASHINFER_LOGDEST`` + - str + - ``stdout`` + - Log destination: ``stdout``, ``stderr``, or file path + +Process ID Substitution +^^^^^^^^^^^^^^^^^^^^^^^^ + +Use ``%i`` in file paths for automatic process ID substitution (useful for multi-GPU training): + +.. code-block:: bash + + export FLASHINFER_LOGDEST="flashinfer_log_%i.txt" # โ†’ flashinfer_log_12345.txt + + +Miscellaneous Notes and Examples +--------------------------------- + +CUDA Graph Compatibility +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Level 5 statistics are **automatically skipped during CUDA graph capture** to avoid synchronization issues. + +.. code-block:: python + + # This works correctly - no synchronization errors + with torch.cuda.graph(cuda_graph): + result = mm_fp4(a, b, scales, ...) 
# Level 5 logging active + # Statistics automatically skipped during capture + +Output shows: ``[statistics skipped: CUDA graph capture in progress]`` + +Process IDs for Multi-GPU Environments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + # Use %i for process ID substitution + export FLASHINFER_LOGLEVEL=3 + export FLASHINFER_LOGDEST="logs/flashinfer_api_%i.log" + + torchrun --nproc_per_node=8 awesome_script_that_uses_FlashInfer.py + + # Creates separate logs: + # logs/flashinfer_api_12345.log (rank 0) + # logs/flashinfer_api_12346.log (rank 1) + # ... + +Level 0 has zero overhead +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead. diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py new file mode 100644 index 0000000000..734d6bae28 --- /dev/null +++ b/flashinfer/api_logging.py @@ -0,0 +1,565 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import enum +import functools +import inspect +import logging +import os +import sys +from typing import Any, Callable +import contextlib +import torch + + +# Helper function to substitute %i with process ID in file paths +def _substitute_process_id(path: str) -> str: + """ + Replace %i with the current process ID in a path. + + This is useful for multi-process/multi-GPU environments where each process + needs its own log file. 
+ """ + if "%i" in path: + return path.replace("%i", str(os.getpid())) + return path + + +# Read environment variables once at module load time +_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) +_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_LOGDEST", "stdout")) + +# Create logger using Python's logging library +_logger = logging.getLogger("flashinfer.api") + + +def _setup_logger(): + """Set up the logger based on environment variables.""" + if _API_LOG_LEVEL == 0: + # Completely disable logging for zero overhead + _logger.addHandler(logging.NullHandler()) + _logger.setLevel(logging.CRITICAL + 1) # Higher than any level + return + + # All enabled levels use loggging.DEBUG; verbosity is controlled by FLASHINFER_LOGLEVEL instead + _logger.setLevel(logging.DEBUG) + + # Remove any existing handlers + _logger.handlers.clear() + + # Create handler based on destination + if _API_LOG_DEST == "stdout": + handler = logging.StreamHandler(sys.stdout) + elif _API_LOG_DEST == "stderr": + handler = logging.StreamHandler(sys.stderr) + else: + handler = logging.FileHandler(_API_LOG_DEST, mode="a") + + # Use a simple formatter (we'll add timestamps manually to key lines) + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + + _logger.addHandler(handler) + _logger.propagate = False # Don't propagate to root logger + + +# Initialize logger at module load time +_setup_logger() + + +def _get_timestamp() -> str: + """Get current timestamp in the format [YYYY-MM-DD HH:MM:SS].""" + from datetime import datetime + + return datetime.now().strftime("[%Y-%m-%d %H:%M:%S]") + + +def _log_system_info(): + """Log system information once at module initialization.""" + if _API_LOG_LEVEL == 0: + return + + lines = [] + lines.append("=" * 80) + lines.append(f"{_get_timestamp()} FlashInfer API Logging - System Information") + lines.append("=" * 80) + + try: + # FlashInfer version + try: + from .version import __version__ as flashinfer_version + + lines.append(f"FlashInfer version: {flashinfer_version}") + except Exception: + lines.append("FlashInfer version: ") + + # CUDA toolkit version + cuda_version = torch.version.cuda + if cuda_version: + lines.append(f"CUDA toolkit version: {cuda_version}") + else: + lines.append("CUDA toolkit version: ") + + # cuDNN version + try: + if torch.backends.cudnn.is_available(): + cudnn_version = torch.backends.cudnn.version() + if cudnn_version: + lines.append(f"cuDNN version: {cudnn_version}") + else: + lines.append("cuDNN version: ") + else: + lines.append("cuDNN version: ") + except Exception as e: + lines.append(f"cuDNN version: ") + + # GPU information (if CUDA is available) + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + lines.append(f"Number of GPUs: {device_count}") + + # Log information for each GPU + for i in range(device_count): + try: + gpu_name = torch.cuda.get_device_name(i) + capability = torch.cuda.get_device_capability(i) + sm_arch = capability[0] * 10 + capability[1] + lines.append(f" GPU {i}: {gpu_name}") + lines.append( + f" Compute capability: {capability[0]}.{capability[1]} (SM{sm_arch})" + ) + except Exception as e: + lines.append(f" GPU {i}: ") + else: + lines.append("CUDA: Not available (CPU-only mode)") + + # PyTorch version + lines.append(f"PyTorch version: {torch.__version__}") + + except Exception as e: + lines.append(f"Error gathering system information: {e}") + + lines.append("=" * 80) + lines.append("") # Empty line for readability + + _logger.debug("\n".join(lines)) + 
+ +# Log system information once at module load time (if logging is enabled) +_log_system_info() + + +def _format_value(value: Any, level: int, indent: int = 0) -> str: + """ + Format a value for logging based on the log level. + + Parameters + ---------- + value : Any + The value to format + level : int + The logging level (1, 2, or 3) + indent : int + The indentation level for nested structures + + Returns + ------- + str + Formatted string representation of the value + """ + indent_str = " " * indent + + # Handle None + if value is None: + return f"{indent_str}None" + + # Handle Enum types + if isinstance(value, enum.Enum): + # Show both the name and value of the enum + return ( + f"{indent_str}{value.__class__.__name__}.{value.name} (value={value.value})" + ) + + # Handle torch.Tensor + if isinstance(value, torch.Tensor): + if level == 1: + return f"{indent_str}Tensor(...)" + + # Level 3+: Show metadata + lines = [f"{indent_str}Tensor("] + lines.append(f"{indent_str} shape={tuple(value.shape)}") + lines.append(f"{indent_str} stride={tuple(value.stride())}") + lines.append(f"{indent_str} dtype={value.dtype}") + lines.append(f"{indent_str} device={value.device}") + lines.append(f"{indent_str} requires_grad={value.requires_grad}") + lines.append(f"{indent_str} is_contiguous={value.is_contiguous()}") + + # Level 5: Add statistics + if level >= 5: + try: + # Skip statistics if we're in CUDA graph capture mode + # (operations like .min()/.max()/.mean() cause synchronization issues) + is_capturing = False + if value.is_cuda and hasattr(torch.cuda, "is_current_stream_capturing"): + with contextlib.suppress(Exception): + is_capturing = torch.cuda.is_current_stream_capturing() + + if is_capturing: + lines.append( + f"{indent_str} [statistics skipped: CUDA graph capture in progress]" + ) + elif value.numel() > 0: + # Convert to float for statistics if possible + if value.dtype in [ + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ]: + val_float = value.float() + lines.append(f"{indent_str} min={val_float.min().item():.6f}") + lines.append(f"{indent_str} max={val_float.max().item():.6f}") + lines.append( + f"{indent_str} mean={val_float.mean().item():.6f}" + ) + nan_count = torch.isnan(val_float).sum().item() + lines.append(f"{indent_str} nan_count={nan_count}") + inf_count = torch.isinf(val_float).sum().item() + lines.append(f"{indent_str} inf_count={inf_count}") + elif value.dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ]: + lines.append(f"{indent_str} min={value.min().item()}") + lines.append(f"{indent_str} max={value.max().item()}") + lines.append( + f"{indent_str} mean={value.float().mean().item():.6f}" + ) + except Exception as e: + lines.append(f"{indent_str} [statistics error: {e}]") + + lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle FP4Tensor (custom FlashInfer type) + if hasattr(value, "__class__") and value.__class__.__name__ == "FP4Tensor": + if level == 1: + return f"{indent_str}FP4Tensor(...)" + + lines = [f"{indent_str}FP4Tensor("] + lines.append( + f"{indent_str} data={_format_value(value.data, level, indent + 1)}" + ) + lines.append( + f"{indent_str} scale={_format_value(value.scale, level, indent + 1)}" + ) + lines.append(f"{indent_str} scale_start_index={value.scale_start_index}") + if hasattr(value, "original_shape") and value.original_shape is not None: + lines.append(f"{indent_str} original_shape={value.original_shape}") + 
lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle lists + if isinstance(value, list): + if len(value) == 0: + return f"{indent_str}[]" + if level == 1: + return f"{indent_str}[list with {len(value)} items]" + + lines = [f"{indent_str}["] + for i, item in enumerate(value): + lines.append( + f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}" + ) + lines.append(f"{indent_str}]") + return "\n".join(lines) + + # Handle tuples + if isinstance(value, tuple): + if len(value) == 0: + return f"{indent_str}()" + if level == 1: + return f"{indent_str}(tuple with {len(value)} items)" + + lines = [f"{indent_str}("] + for i, item in enumerate(value): + lines.append( + f"{indent_str} [{i}]: {_format_value(item, level, indent + 1)}" + ) + lines.append(f"{indent_str})") + return "\n".join(lines) + + # Handle dictionaries + if isinstance(value, dict): + if len(value) == 0: + return f"{indent_str}{{}}" + if level == 1: + return f"{indent_str}{{dict with {len(value)} keys}}" + + lines = [f"{indent_str}{{"] + for key, val in value.items(): + lines.append( + f"{indent_str} {repr(key)}: {_format_value(val, level, indent + 1)}" + ) + lines.append(f"{indent_str}}}") + return "\n".join(lines) + + # Handle numeric types (int, float, bool) + if isinstance(value, (int, float, bool, complex)): + return f"{indent_str}{value}" + + # Handle strings + if isinstance(value, str): + return f"{indent_str}{repr(value)}" + + # Default: use repr + try: + return f"{indent_str}{repr(value)}" + except Exception: + return f"{indent_str}<{type(value).__name__} object>" + + +def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict: + """ + Extract parameters that have default values but were not explicitly provided. + + Parameters + ---------- + func : Callable + The function being called + args : tuple + Positional arguments that were provided + kwargs : dict + Keyword arguments that were provided + + Returns + ------- + dict + Dictionary of parameter names to default values for parameters that were not provided + """ + try: + sig = inspect.signature(func) + default_params = {} + + # Determine which parameters were NOT provided + for i, (param_name, param) in enumerate(sig.parameters.items()): + # Skip if parameter has no default + if param.default is inspect.Parameter.empty: + continue + + # Check if this parameter was provided + provided = False + + # Check positional args and keyword args + if i < len(args) or param_name in kwargs: + provided = True + + # If not provided, record the default value + if not provided: + default_params[param_name] = param.default + + return default_params + except Exception: + # If we can't inspect the signature, return empty dict + return {} + + +def _log_function_inputs( + func: Callable, func_name: str, args: tuple, kwargs: dict, level: int +) -> None: + """ + Log function inputs BEFORE execution for crash safety. + + This ensures inputs are captured even if the function crashes with a CUDA error. 
+ + Parameters + ---------- + func : Callable + The function being called (needed to extract default parameters) + func_name : str + Name of the function being called + args : tuple + Positional arguments + kwargs : dict + Keyword arguments + level : int + Logging level (3 or 5) + """ + lines = [] + lines.append("=" * 80) + lines.append(f"{_get_timestamp()} FlashInfer API Call: {func_name}") + lines.append("-" * 80) + + # Log explicitly provided inputs + if args or kwargs: + # Positional arguments + if args: + lines.append("Positional input arguments:") + for i, arg in enumerate(args): + lines.append(f" arg[{i}]:") + lines.append(_format_value(arg, level, indent=2)) + + # Keyword arguments + if kwargs: + lines.append("Keyword input arguments:") + for key, value in kwargs.items(): + lines.append(f" {key}=") + lines.append(_format_value(value, level, indent=2)) + else: + lines.append("(No explicit arguments)") + + # Log default parameters that were not explicitly provided + default_params = _get_default_params(func, args, kwargs) + if default_params: + lines.append("Default parameters (not explicitly provided):") + for param_name, default_value in default_params.items(): + lines.append(f" {param_name}= [DEFAULT]") + lines.append(_format_value(default_value, level, indent=2)) + + _logger.debug("\n".join(lines)) + + +def _log_function_outputs(func_name: str, result: Any, level: int) -> None: + """ + Log function outputs AFTER successful execution. + + Parameters + ---------- + func_name : str + Name of the function + result : Any + Function return value + level : int + Logging level (3 or 5) + """ + lines = [] + # Log outputs + lines.append("Output value:") + lines.append(_format_value(result, level, indent=1)) + + lines.append("=" * 80) + lines.append("") # Empty line for readability + + _logger.debug("\n".join(lines)) + + +def flashinfer_api(func: Callable = None) -> Callable: + """ + Decorator to FlashInfer's APIs. + + Currently logs input and output values of the function using Python's logging library. + This decorator integrates with Python's standard logging infrastructure while + maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0). + + NOTE/TODO: Not all FlashInfer APIs are decorated with this decorator yet. This is a work in progress. + + Environment Variables + --------------------- + FLASHINFER_LOGLEVEL : int (default: 0) + - 0: No logging (zero overhead - decorator returns original function) + - 1: Log function name only (logged BEFORE execution - crash-safe) + - 3: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) + - 5: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) + + FLASHINFER_LOGDEST : str (default: "stdout") + - "stdout": Log to standard output + - "stderr": Log to standard error + - : Log to specified file path + - Use %i in path for process ID substitution (e.g., "log_%i.txt" -> "log_12345.txt") + + Examples + -------- + Basic usage: + + >>> @flashinfer_api + ... def my_function(x, y): + ... return x + y + + Notes + ----- + - Key header lines include a timestamp in the format: [YYYY-MM-DD HH:MM:SS] + (e.g., "FlashInfer API Call: function_name", "FlashInfer API Logging - System Information") + - When FLASHINFER_LOGLEVEL=0, the decorator has truly zero overhead + as it returns the original function unchanged. 
+ - Function names and inputs are logged BEFORE execution: + - Level 1: Function name only + - Levels 3-5: Function name + inputs with metadata + This means critical debugging information is preserved even if the function + crashes (e.g., CUDA illegal memory access, out-of-bounds, etc.). + - Outputs are logged AFTER successful execution for levels 3 and 5. + - **CUDA Graph Compatibility**: At level 5, tensor statistics (min/max/mean/nan_count) + are automatically skipped during CUDA graph capture to avoid synchronization issues. + The message "[statistics skipped: CUDA graph capture in progress]" will be logged. + - The %i pattern is automatically replaced with the process ID for multi-process environments. + - The logger does not propagate to the root logger to avoid duplicate logs. + """ + # If logging is disabled, return original function with zero overhead + if _API_LOG_LEVEL == 0: + if func is None: + return lambda f: f + return func + + def decorator(f: Callable) -> Callable: + @functools.wraps(f) + def wrapper(*args, **kwargs): + # Determine function name (with class name if applicable) + func_name = f.__name__ + if args and hasattr(args[0], "__class__"): + try: + class_name = args[0].__class__.__name__ + if "Wrapper" in class_name or class_name in [ + "BatchMLAPagedAttentionWrapper" + ]: + func_name = f"{class_name}.{func_name}" + except Exception: + pass + + # Log BEFORE execution (crash-safe for all levels!) + try: + if _API_LOG_LEVEL == 1: + # Level 1: Just log function name before execution (crash-safe) + _logger.debug( + f"{_get_timestamp()} FlashInfer API Call: {func_name}" + ) + elif _API_LOG_LEVEL >= 3: + # Level 3+: Log full inputs before execution (crash-safe) + _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL) + except Exception as e: + _logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}") + + # Call the original function (may crash here with CUDA errors) + result = f(*args, **kwargs) + + # Log outputs AFTER successful execution (level 3+ only) + try: + if _API_LOG_LEVEL >= 3: + # Level 3+: Log outputs (inputs were already logged above) + _log_function_outputs(func_name, result, _API_LOG_LEVEL) + except Exception as e: + _logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}") + + return result + + return wrapper + + if func is None: + return decorator + return decorator(func) diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py index 6ef13b997f..195ca2d49d 100644 --- a/flashinfer/cudnn/decode.py +++ b/flashinfer/cudnn/decode.py @@ -3,6 +3,7 @@ import torch +from ..api_logging import flashinfer_api from .utils import get_cudnn_fmha_gen_module try: @@ -252,6 +253,7 @@ def _batch_decode_with_kv_cache( return out +@flashinfer_api def cudnn_batch_decode_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index fc573cf7cb..b8c09a66ee 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -3,6 +3,7 @@ import torch +from ..api_logging import flashinfer_api from .utils import get_cudnn_fmha_gen_module try: @@ -383,6 +384,7 @@ def _batch_prefill_with_kv_cache( return out, None +@flashinfer_api def cudnn_batch_prefill_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index af8dda0345..ab34ba8857 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -21,6 +21,7 @@ import torch +from .api_logging import flashinfer_api from .xqa import xqa, xqa_mla from .cudnn 
import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .jit import ( @@ -312,6 +313,7 @@ def get_trtllm_gen_fmha_module(): return op +@flashinfer_api def single_decode_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -388,6 +390,7 @@ def single_decode_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... +@flashinfer_api def single_decode_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -646,6 +649,7 @@ class BatchDecodeWithPagedKVCacheWrapper: manages the lifecycle of these data structures. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -809,6 +813,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api def plan( self, indptr: torch.Tensor, @@ -1162,6 +1167,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api def run( self, q: torch.Tensor, @@ -2059,6 +2065,7 @@ def _fake_paged_run( ) +@flashinfer_api def trtllm_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2332,6 +2339,7 @@ def trtllm_batch_decode_with_kv_cache( # xqa uses NHD layout +@flashinfer_api def xqa_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -2516,6 +2524,7 @@ def _check_trtllm_gen_mla_shape( ) +@flashinfer_api def trtllm_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, @@ -2677,6 +2686,7 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError(f"Backend {backend} not supported") +@flashinfer_api def xqa_batch_decode_with_kv_cache_mla( query: torch.Tensor, kv_cache: torch.Tensor, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 3c5e7a09c5..7b53c3f82c 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from ..api_logging import flashinfer_api from ..autotuner import ( AutoTuner, DynamicTensorSpec, @@ -685,6 +686,7 @@ def _fake_cutlass_fused_moe( # ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121 +@flashinfer_api def cutlass_fused_moe( input: torch.Tensor, token_selected_experts: torch.Tensor, @@ -1857,6 +1859,7 @@ def _fake_trtllm_fp4_block_scale_moe( ) +@flashinfer_api def trtllm_bf16_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -1937,6 +1940,7 @@ def trtllm_bf16_moe( ) +@flashinfer_api def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2010,6 +2014,7 @@ def trtllm_fp8_per_tensor_scale_moe( ) +@flashinfer_api def trtllm_fp8_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2087,6 +2092,7 @@ def trtllm_fp8_block_scale_moe( ) +@flashinfer_api def trtllm_fp4_block_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -2216,6 +2222,7 @@ def trtllm_fp4_block_scale_moe( ) +@flashinfer_api def trtllm_fp4_block_scale_routed_moe( topk_ids: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 589c651aca..251e2a4682 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -22,6 +22,7 @@ from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch +from ..api_logging import flashinfer_api from ..autotuner import ( AutoTuner, 
ConstraintSpec, @@ -539,6 +540,7 @@ def forward( ) +@flashinfer_api def tgv_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -884,6 +886,7 @@ def reset_workspace_buffer( self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer + @flashinfer_api def run( self, x: torch.Tensor, @@ -1551,6 +1554,7 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): return (tuple(block_scale_shape), tuple(block_scale_stride)) +@flashinfer_api def mm_fp8( a: torch.Tensor, b: torch.Tensor, @@ -2024,6 +2028,7 @@ def _heuristic_func_mm_fp4( common_check=_check_mm_fp4_problem_size, heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) +@flashinfer_api def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -2281,6 +2286,7 @@ def _heuristic_func_bmm_fp8( common_check=_check_bmm_fp8_problem_size, heuristic_func=_heuristic_func_bmm_fp8, ) +@flashinfer_api def bmm_fp8( A: torch.Tensor, B: torch.Tensor, @@ -2372,6 +2378,7 @@ def bmm_fp8( return out +@flashinfer_api def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -2623,6 +2630,7 @@ def forward( ) +@flashinfer_api def gemm_fp8_nt_blockscaled( a: torch.Tensor, b: torch.Tensor, @@ -2651,6 +2659,7 @@ def gemm_fp8_nt_blockscaled( ) +@flashinfer_api def group_gemm_fp8_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k) @@ -2813,6 +2822,7 @@ def group_gemm_fp8_nt_groupwise( return out +@flashinfer_api def group_gemm_mxfp8_mxfp4_nt_groupwise( a: torch.Tensor, # (cum_m, k) b: torch.Tensor, # (batch_size, n, k // 2) @@ -2980,6 +2990,7 @@ def get_deepgemm_sm100_module(): return module +@flashinfer_api def group_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (m, k) b: torch.Tensor, # (batch_size, n, k) @@ -3110,6 +3121,7 @@ def group_deepgemm_fp8_nt_groupwise( return out +@flashinfer_api def batch_deepgemm_fp8_nt_groupwise( a: torch.Tensor, # (batch_size, m, k) b: torch.Tensor, # (batch_size, n, k) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index da57d94e6b..22cf029a2e 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -19,6 +19,7 @@ import torch +from .api_logging import flashinfer_api from .jit import gen_batch_mla_module from .jit.mla import gen_mla_module from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend @@ -129,6 +130,7 @@ class BatchMLAPagedAttentionWrapper: torch.Size([114, 128, 512]) """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -199,6 +201,7 @@ def __init__( else: self._backend = backend + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -333,6 +336,7 @@ def run( return_lse_base_on_e: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api def run( self, q_nope: torch.Tensor, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 47d725c5d3..a2c4ceb0a8 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -22,6 +22,7 @@ import torch +from .api_logging import flashinfer_api from .jit import ( gen_batch_prefill_module, gen_customize_batch_prefill_module, @@ -873,6 +874,7 @@ def _fake_paged_run( ) +@flashinfer_api def single_prefill_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -957,6 +959,7 @@ def single_prefill_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... +@flashinfer_api def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -1325,6 +1328,7 @@ class BatchPrefillWithPagedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. 
""" + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -1520,6 +1524,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -1976,6 +1981,7 @@ def run( window_left: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api def run( self, q: torch.Tensor, @@ -2350,6 +2356,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, @@ -2493,6 +2500,7 @@ def reset_workspace_buffer( pin_memory=True, ) + @flashinfer_api def plan( self, qo_indptr: torch.Tensor, @@ -2837,6 +2845,7 @@ def run( enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... + @flashinfer_api def run( self, q: torch.Tensor, @@ -3193,6 +3202,7 @@ def get_trtllm_gen_fmha_module(): return op +@flashinfer_api def trtllm_ragged_attention_deepseek( query: torch.Tensor, key: torch.Tensor, @@ -3327,6 +3337,7 @@ def trtllm_ragged_attention_deepseek( return out +@flashinfer_api def trtllm_batch_context_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py new file mode 100644 index 0000000000..6ead5e7d6b --- /dev/null +++ b/tests/utils/test_logging.py @@ -0,0 +1,588 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import sys +import tempfile +from enum import Enum +from pathlib import Path + +import pytest +import torch + + +# Test enum classes +class TestEnum(Enum): + """Test enum with integer values.""" + + OPTION_A = 0 + OPTION_B = 1 + OPTION_C = 2 + + +class StringEnum(Enum): + """Test enum with string values. 
Names are for testing purposes.""" + + MODE_STANDARD = "standard" + MODE_OPTIMIZED = "optimized" + + +class TestAPILogging: + """Test suite for FlashInfer API logging infrastructure.""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Reset environment and reimport logging module for each test.""" + # Store original environment + original_level = os.environ.get("FLASHINFER_LOGLEVEL") + original_dest = os.environ.get("FLASHINFER_LOGDEST") + + yield + + # Restore original environment + if original_level is not None: + os.environ["FLASHINFER_LOGLEVEL"] = original_level + elif "FLASHINFER_LOGLEVEL" in os.environ: + del os.environ["FLASHINFER_LOGLEVEL"] + + if original_dest is not None: + os.environ["FLASHINFER_LOGDEST"] = original_dest + elif "FLASHINFER_LOGDEST" in os.environ: + del os.environ["FLASHINFER_LOGDEST"] + + # Force reimport to pick up new environment variables + if "flashinfer.api_logging" in sys.modules: + del sys.modules["flashinfer.api_logging"] + + def setup_logging(self, level: int, dest: str = "stdout"): + """Helper to set up logging environment and reimport.""" + os.environ["FLASHINFER_LOGLEVEL"] = str(level) + os.environ["FLASHINFER_LOGDEST"] = dest + + # Force reimport + if "flashinfer.api_logging" in sys.modules: + del sys.modules["flashinfer.api_logging"] + + from flashinfer.api_logging import flashinfer_api + + return flashinfer_api + + def test_level_0_zero_overhead(self): + """Test that level 0 has truly zero overhead (returns original function).""" + decorator = self.setup_logging(level=0) + + def original_func(x, y): + return x + y + + decorated_func = decorator(original_func) + + # At level 0, decorator should return the original function unchanged + assert decorated_func is original_func + assert decorated_func(5, 3) == 8 + + def test_level_1_function_name(self): + """Test that level 1 logs function name only.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + @decorator + def test_function(x, y): + return x + y + + result = test_function(10, 20) + assert result == 30 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + assert "FlashInfer API Call: test_function" in log_contents + # Level 1 should not log inputs/outputs details + assert "Positional input arguments" not in log_contents + assert "Output value" not in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_level_3_inputs_outputs(self): + """Test that level 3 logs inputs and outputs with metadata.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(tensor, value): + return tensor * value + + tensor = torch.tensor([1.0, 2.0, 3.0]) + test_function(tensor, 2.0) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log function name + assert "FlashInfer API Call: test_function" in log_contents + + # Should log inputs + assert "Positional input arguments" in log_contents + assert "arg[0]" in log_contents + assert "Tensor(" in log_contents + assert "shape=(3,)" in log_contents + assert "dtype=torch.float32" in log_contents + + # Should log outputs + assert "Output value:" in log_contents + + # Should NOT log statistics (level 5 only) + assert "min=" not in log_contents + assert "max=" not in log_contents + finally: 
+ Path(log_file).unlink(missing_ok=True) + + def test_level_5_statistics(self): + """Test that level 5 logs tensor statistics.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=5, dest=log_file) + + @decorator + def test_function(tensor): + return tensor + 1.0 + + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + test_function(tensor) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log statistics + assert "min=" in log_contents + assert "max=" in log_contents + assert "mean=" in log_contents + assert "nan_count=" in log_contents + assert "inf_count=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_enum_logging(self): + """Test that enum values are logged with name and value.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(mode: TestEnum, strategy: StringEnum): + return f"{mode.name}_{strategy.name}" + + test_function(TestEnum.OPTION_B, StringEnum.MODE_OPTIMIZED) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should show enum name and value + assert "TestEnum.OPTION_B" in log_contents + assert "(value=1)" in log_contents + assert "StringEnum.MODE_OPTIMIZED" in log_contents + assert ( + "(value=optimized)" in log_contents + or "(value='optimized')" in log_contents + or '(value="optimized")' in log_contents + ) + finally: + Path(log_file).unlink(missing_ok=True) + + def test_default_parameters(self): + """Test that default parameters are logged separately.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(x, y=10, z=20, mode=TestEnum.OPTION_A): + return x + y + z + + # Call with only required argument + result = test_function(5) + assert result == 35 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should show default parameters section + assert "Default parameters (not explicitly provided)" in log_contents + assert "[DEFAULT]" in log_contents + + # Should show the default values + assert "y=" in log_contents + assert "z=" in log_contents + assert "mode=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_explicit_vs_default_parameters(self): + """Test that explicitly provided parameters are not shown in defaults.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(x, y=10, z=20): + return x + y + z + + # Call with some explicit parameters + test_function(5, y=100) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # y should be in keyword arguments (explicit) + assert "Keyword input arguments:" in log_contents + + # Only z should be in defaults + lines = log_contents.split("\n") + default_section_started = False + defaults_found = [] + for line in lines: + if "Default parameters" in line: + default_section_started = True + if default_section_started and "=" in line and "[DEFAULT]" in line: + defaults_found.append(line) + + # Should have only one default parameter (z) + assert len(defaults_found) == 1 + 
assert "z=" in defaults_found[0] + finally: + Path(log_file).unlink(missing_ok=True) + + def test_class_method_logging(self): + """Test that class methods log with class name.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + class TestWrapper: + @decorator + def run(self, x): + return x * 2 + + wrapper = TestWrapper() + result = wrapper.run(5) + assert result == 10 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log class name for Wrapper classes + assert "TestWrapper.run" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_crash_safety_inputs_logged_before_execution(self): + """Test that inputs are logged BEFORE execution (crash-safe).""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def crashing_function(x, y): + raise RuntimeError("Simulated crash") + + # Call the function and expect it to crash + with pytest.raises(RuntimeError, match="Simulated crash"): + crashing_function(42, 99) + + # Check that inputs were still logged + with open(log_file, "r") as f: + log_contents = f.read() + + # Inputs should be in the log even though function crashed + assert "FlashInfer API Call: crashing_function" in log_contents + assert "Positional input arguments" in log_contents + assert "arg[0]" in log_contents + assert "42" in log_contents + assert "arg[1]" in log_contents + assert "99" in log_contents + + # Outputs should NOT be in the log (function crashed) + assert "Output value:" not in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_different_data_types(self): + """Test logging of various data types.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function( + int_val, + float_val, + bool_val, + str_val, + list_val, + tuple_val, + dict_val, + none_val, + ): + return "success" + + test_function( + 42, 3.14, True, "hello", [1, 2, 3], (4, 5, 6), {"key": "value"}, None + ) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log all types correctly + assert "42" in log_contents + assert "3.14" in log_contents + assert "True" in log_contents + assert "'hello'" in log_contents + assert "None" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_tensor_metadata(self): + """Test that tensor metadata is logged correctly.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(tensor): + return tensor + + # Create a tensor with specific properties + tensor = torch.randn(2, 3, 4, dtype=torch.float32, device="cpu") + tensor = tensor.contiguous() + tensor.requires_grad = False + + test_function(tensor) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log all metadata + assert "shape=(2, 3, 4)" in log_contents + assert "dtype=torch.float32" in log_contents + assert "device=cpu" in log_contents + assert "requires_grad=False" in log_contents + assert "is_contiguous=True" in log_contents + assert "stride=" in log_contents + finally: + 
Path(log_file).unlink(missing_ok=True) + + def test_nested_structures(self): + """Test logging of nested data structures.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(nested): + return nested + + # Create nested structure + nested = { + "list": [1, 2, 3], + "dict": {"inner": "value"}, + "tuple": (4, 5), + } + + test_function(nested) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should handle nested structures + assert "list" in log_contents + assert "dict" in log_contents + assert "tuple" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_decorator_with_and_without_parentheses(self): + """Test that decorator works both as @decorator and @decorator().""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + # Without parentheses + @decorator + def func1(x): + return x + 1 + + # With parentheses + @decorator() + def func2(x): + return x + 2 + + result1 = func1(10) + result2 = func2(20) + + assert result1 == 11 + assert result2 == 22 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + assert "func1" in log_contents + assert "func2" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + def test_multiple_calls_same_function(self): + """Test that multiple calls to the same function are all logged.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=1, dest=log_file) + + @decorator + def test_function(x): + return x + + # Call multiple times + for i in range(3): + test_function(i) + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should have 3 log entries + assert log_contents.count("FlashInfer API Call: test_function") == 3 + finally: + Path(log_file).unlink(missing_ok=True) + + def test_kwargs_logging(self): + """Test that keyword arguments are logged correctly.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=3, dest=log_file) + + @decorator + def test_function(a, b, c): + return a + b + c + + # Call with keyword arguments + result = test_function(a=1, b=2, c=3) + assert result == 6 + + # Check log contents + with open(log_file, "r") as f: + log_contents = f.read() + + # Should log keyword arguments + assert "Keyword input arguments:" in log_contents + assert "a=" in log_contents + assert "b=" in log_contents + assert "c=" in log_contents + finally: + Path(log_file).unlink(missing_ok=True) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_cuda_graph_compatibility(self): + """Test that level 5 logging is compatible with CUDA graph capture.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f: + log_file = f.name + + try: + decorator = self.setup_logging(level=5, dest=log_file) + + @decorator + def test_cuda_function(tensor): + return tensor * 2.0 + + # Create a CUDA tensor + tensor = torch.randn(10, 10, device="cuda") + + # Test 1: Normal execution (should have statistics) + test_cuda_function(tensor) + + with open(log_file, "r") as f: + log_normal = f.read() + + # Should have 
statistics in normal execution + # (unless PyTorch version is too old) + if hasattr(torch.cuda, "is_current_stream_capturing"): + # Normal execution should have min/max OR statistics error + has_stats = "min=" in log_normal or "statistics error" in log_normal + assert has_stats, "Expected statistics or error in normal execution" + + # Clear log file + with open(log_file, "w") as f: + f.write("") + + # Test 2: CUDA graph capture (should skip statistics) + if hasattr(torch.cuda, "CUDAGraph"): + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + test_cuda_function(tensor) + + with open(log_file, "r") as f: + log_capture = f.read() + + # Should skip statistics during capture + assert ( + "[statistics skipped: CUDA graph capture in progress]" + in log_capture + or "statistics" not in log_capture + ), "Expected statistics to be skipped during CUDA graph capture" + finally: + Path(log_file).unlink(missing_ok=True) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 5e110040cb103fad08e3947485b5d532dd83a995 Mon Sep 17 00:00:00 2001 From: Raayan Dhar <58057652+raayandhar@users.noreply.github.com> Date: Fri, 21 Nov 2025 23:26:02 -0800 Subject: [PATCH 080/130] fix: add a check for int32 indices in sampling.py (#2127) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description New function to validate that the indices type, when provided, is `int32`. To close https://github.com/flashinfer-ai/flashinfer/issues/2115. There are now two separate functions doing checking in this file. I will move them to the C++ side later when I have some more bandwidth, probably after Thanksgiving. Just a short fix for now. You can close if you'd rather wait for that. ## ๐Ÿ” Related Issues https://github.com/flashinfer-ai/flashinfer/issues/2115 Relevant to the issue. Now running their code: ``` (flashinfer) raayan@uril-1:~/projects/flashinfer$ python test.py tensor([1, 1, 0, 0], device='cuda:0', dtype=torch.int32) Traceback (most recent call last): File "/home/raayan/projects/flashinfer/test.py", line 15, in incorrect_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 1031, in top_k_top_p_sampling_from_logits _check_indices_dtype(indices) File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 487, in _check_indices_dtype raise ValueError(f"indices must have dtype torch.int32, got {indices.dtype}") ValueError: indices must have dtype torch.int32, got torch.int64 ``` ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Improvements** * Enforced that indices passed to sampling operations must use int32, adding runtime validation before sampling. 
* **Documentation** * Clarified docstrings to state the int32 requirement for indices parameters. * **Tests** * Updated and expanded tests to cover the new dtype validation paths and related error cases. โœ๏ธ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- flashinfer/sampling.py | 27 ++++++++++++++++++++------- tests/utils/test_sampling.py | 25 ++++++++++++++++++------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 3ac6367ff5..a7f334a01a 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -481,6 +481,12 @@ def _to_tensor_scalar_tuple(x): return (None, x) +def _check_indices_dtype(indices: Optional[torch.Tensor]) -> None: + """Validate indices dtype.""" + if indices is not None and indices.dtype != torch.int32: + raise ValueError(f"indices must have dtype torch.int32, got {indices.dtype}") + + def _check_tensor_param(param: Any, tensor: torch.Tensor) -> None: """Validate sampling parameters.""" if isinstance(param, torch.Tensor): @@ -576,7 +582,7 @@ def sampling_from_logits( shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. indices: Optional[torch.Tensor] - Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in logits. + Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in logits. For example, if indices[i] = j, then the i-th output will be sampled from logits[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of logits. @@ -612,6 +618,7 @@ def sampling_from_logits( if check_nan: if torch.any(torch.isnan(logits)): raise ValueError("Input logits contains NaN.") + _check_indices_dtype(indices) return get_sampling_module().sampling_from_logits( logits, indices, deterministic, generator ) @@ -634,7 +641,7 @@ def sampling_from_probs( shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique probability distributions. indices: Optional[torch.Tensor] - Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. @@ -676,6 +683,7 @@ def sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") + _check_indices_dtype(indices) return get_sampling_module().sampling_from_probs( probs, indices, deterministic, generator ) @@ -708,7 +716,7 @@ def top_p_sampling_from_probs( If a float, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] - Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. 
If indices is not provided, the i-th output will be sampled from the i-th row of probs. @@ -758,6 +766,7 @@ def top_p_sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") + _check_indices_dtype(indices) _check_tensor_param(top_p, probs) return get_sampling_module().top_p_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator @@ -791,7 +800,7 @@ def top_k_sampling_from_probs( If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] - Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. @@ -841,6 +850,7 @@ def top_k_sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") + _check_indices_dtype(indices) _check_tensor_param(top_k, probs) return get_sampling_module().top_k_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator @@ -875,7 +885,7 @@ def min_p_sampling_from_probs( If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] - Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. @@ -920,6 +930,7 @@ def min_p_sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") + _check_indices_dtype(indices) _check_tensor_param(min_p, probs) return get_sampling_module().min_p_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator @@ -960,7 +971,7 @@ def top_k_top_p_sampling_from_logits( If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. indices: Optional[torch.Tensor] - Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. @@ -1018,6 +1029,7 @@ def top_k_top_p_sampling_from_logits( top_k_mask_logits top_p_sampling_from_probs """ + _check_indices_dtype(indices) _check_tensor_param(top_k, logits) _check_tensor_param(top_p, logits) if filter_apply_order == "top_k_first": @@ -1082,7 +1094,7 @@ def top_k_top_p_sampling_from_probs( If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. 
indices: Optional[torch.Tensor] - Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs. For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. This allows reusing the same probability distribution for multiple outputs. If indices is not provided, the i-th output will be sampled from the i-th row of probs. @@ -1135,6 +1147,7 @@ def top_k_top_p_sampling_from_probs( top_p_renorm_probs top_k_mask_logits """ + _check_indices_dtype(indices) _check_tensor_param(top_k, probs) _check_tensor_param(top_p, probs) if filter_apply_order == "top_k_first": diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py index 9e72c4f49b..4ce93914f4 100644 --- a/tests/utils/test_sampling.py +++ b/tests/utils/test_sampling.py @@ -572,7 +572,7 @@ def test_chain_speculative_sampling( @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) -def test_check_tensor_param_min_p(batch_size, vocab_size, p): +def test_tensor_validation_min_p(batch_size, vocab_size, p): pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) @@ -587,7 +587,7 @@ def test_check_tensor_param_min_p(batch_size, vocab_size, p): flashinfer.sampling.min_p_sampling_from_probs( normalized_prob, torch.tensor( - [[p] * vocab_size] * batch_size, dtype=torch.int, device="cuda:0" + [[p] * vocab_size] * batch_size, dtype=torch.float32, device="cuda:0" ), ) @@ -597,22 +597,33 @@ def test_check_tensor_param_min_p(batch_size, vocab_size, p): match=r"Expected a 1D tensor of shape \(batch_size,\) or scalar.*got a 0-dimensional tensor", ): flashinfer.sampling.min_p_sampling_from_probs( - normalized_prob, torch.tensor(p, dtype=torch.int, device="cuda:0") + normalized_prob, torch.tensor(p, dtype=torch.float32, device="cuda:0") ) - # 4: 1D tensor with a broken batch size raises error (only when batch_size > 1). + # 4: non-int32 indices raises error. + with pytest.raises( + ValueError, + match=r"indices must have dtype torch\.int32, got torch\.int64", + ): + flashinfer.sampling.min_p_sampling_from_probs( + normalized_prob, + torch.tensor([p] * batch_size, dtype=torch.float32, device="cuda:0"), + torch.tensor([p] * batch_size, dtype=torch.int64, device="cuda:0"), + ) + + # 5: 1D tensor with a broken batch size raises error (only when batch_size > 1). if batch_size > 1: with pytest.raises( ValueError, match="Sampling parameter tensor batch size mismatch" ): flashinfer.sampling.min_p_sampling_from_probs( - normalized_prob, torch.tensor([p], dtype=torch.int, device="cuda:0") + normalized_prob, torch.tensor([p], dtype=torch.float32, device="cuda:0") ) - # 5: 1D tensor with the correct batch size works. + # 6: 1D tensor with the correct batch size works. 
samples = flashinfer.sampling.min_p_sampling_from_probs( normalized_prob, - torch.tensor([p] * batch_size, dtype=torch.int, device="cuda:0"), + torch.tensor([p] * batch_size, dtype=torch.float32, device="cuda:0"), ) assert samples.shape == (batch_size,) From 84df81ed8a817f23bc6d430c49dc719fb9376222 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 21 Nov 2025 23:30:09 -0800 Subject: [PATCH 081/130] update autotuner input tensor random range (#2116) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Update autotuner input tensor random range from [0,1) to [-5,5) for larger range and closer to real tensor ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Bug Fixes** * Improved tensor initialization used during autotuning: values are now drawn from a symmetric range around zero ([-5, 5]) with a more uniform-like distribution, yielding more consistent and stable parameter tuning results. โœ๏ธ Tip: You can customize this high-level summary in your review settings. โœ๏ธ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- flashinfer/autotuner.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 6b6a0c3e48..9f5fb67489 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -61,9 +61,9 @@ def __post_init__(self): # Set default tensor_initializers if not provided if self.tensor_initializers is None: self.tensor_initializers = [ - lambda shapes, dtype, device: torch.randn(shapes, device=device).to( - dtype - ) + lambda shapes, dtype, device: ( + torch.rand(shapes, device=device) * 10 - 5 + ).to(dtype) for _ in range(len(self.input_idx)) ] @@ -761,8 +761,8 @@ def _create_tensor_like( def _prepare_input_tensors( self, profile: OptimizationProfile, inputs: List[torch.Tensor] ) -> List[torch.Tensor]: - default_initializer = lambda shapes, dtype, device: torch.rand( - shapes, device=device + default_initializer = lambda shapes, dtype, device: ( + torch.rand(shapes, device=device) * 10 - 5 ).to(dtype) tensors = [] for i, p in enumerate(profile.shapes): From 2439a416597c7ff8eede5ebf83991288077bcb2d Mon Sep 17 00:00:00 2001 From: qsang-nv <200703406+qsang-nv@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:54:29 +0800 Subject: [PATCH 082/130] enable xqa speculative decoding (#2105) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Enable xqa with speculative decoding and add mask tensor in trtllm_batch_decode_with_kv_cache. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! 
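One note on the mask format this change introduces, before the template items: the causal mask is bit-packed, with each query row carrying `divUp(q_seq_len, 32) * 2` uint16 words, where bit `i` of 32-bit group `w` covers kv token `w * 32 + i`. A minimal sketch of unpacking such a mask back to a boolean matrix for sanity checks — `unpack_packed_causal_mask` is an illustrative helper name, not part of this patch, and it assumes a torch build with `uint32` dtype views (>= 2.3):

```python
import torch


def unpack_packed_causal_mask(mask_u16: torch.Tensor, q_seq_len: int) -> torch.Tensor:
    """Expand a packed mask [batch, q, 2 * ceil(q / 32)] (uint16) to a [batch, q, q] bool mask."""
    words = mask_u16.view(torch.uint32).to(torch.int64)   # pair up uint16 words into 32-bit groups
    bit_idx = torch.arange(32, device=mask_u16.device)    # bit i within group w -> kv token w * 32 + i
    bits = (words.unsqueeze(-1) >> bit_idx) & 1           # [batch, q, groups, 32]
    return bits.flatten(-2)[..., :q_seq_len].bool()
```

Round-tripping the `generate_causal_mask` test helper added below through a sketch like this should recover the usual lower-triangular `kv <= q` pattern, which is a cheap sanity check before wiring the mask into `trtllm_batch_decode_with_kv_cache`.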
Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Speculative decoding: multi-token query support (q_seq_len) with optional attention mask threaded end-to-end. * **API** * Public APIs updated to accept q_seq_len and an optional mask; automatic reshaping and runtime checks for multi-token decoding. * **JIT / Build** * JIT now emits SPEC_DEC-enabled variants and includes spec-dec flags in generated specs. * **Backend / Runtime** * Mask propagation and architecture-aware backend selection improved for compatible kernels. * **Tests** * Added helpers and tests to generate causal masks and validate multi-token speculative decoding. โœ๏ธ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> Co-authored-by: yzh119 --- csrc/flashinfer_xqa_binding.cu | 8 +-- csrc/xqa/mha.cu | 27 ++++++++ csrc/xqa/xqa_wrapper.cu | 18 +++-- flashinfer/decode.py | 16 ++++- flashinfer/jit/xqa.py | 36 +++++++--- flashinfer/xqa.py | 45 ++++++++++-- tests/attention/test_trtllm_gen_attention.py | 72 ++++++++++++++++++-- tests/attention/test_xqa_batch_decode.py | 71 ++++++++++++++++++- 8 files changed, 257 insertions(+), 36 deletions(-) diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu index dc06614763..8bcbafafd6 100644 --- a/csrc/flashinfer_xqa_binding.cu +++ b/csrc/flashinfer_xqa_binding.cu @@ -34,11 +34,9 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK tvm::ffi::Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, - tvm::ffi::Optional kvScaleTensor, -#if SPEC_DEC - int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, -#endif - TensorView semaphores, TensorView scratch, bool enable_pdl); + tvm::ffi::Optional kvScaleTensor, int64_t qSeqLen, + tvm::ffi::Optional mask, TensorView semaphores, TensorView scratch, + bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper); diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index af61dc0034..90576367b3 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -1267,6 +1267,23 @@ __device__ inline void addAttentionSinks(ThrdRegRowMax& globalRowSum, } } +#if SPEC_DEC +// SPEC_DEC version: handles head-token mixed layout +__device__ inline void addAttentionSinksSpecDec(ThrdRegRowMax& globalRowSum, + ThrdRegRowMax const globalRowMax, + float const* attentionSinks, uint32_t headGrpSize) { + for (uint32_t i = 0; i < globalRowSum.size; i++) { + uint32_t idxHeadToken = warp_size * i + laneId(); + // In SPEC_DEC, layout is [token0_head0, token0_head1, ..., token1_head0, ...] 
+ // Extract head index from head-token index + uint32_t headIdx = idxHeadToken % headGrpSize; + if (headIdx < headGrpSize && idxHeadToken < rowsPerBlock) { + globalRowSum[i] += expf(attentionSinks[headIdx] - globalRowMax[i]); + } + } +} +#endif + #ifdef NDEBUG __device__ __forceinline__ #else @@ -2169,7 +2186,12 @@ CUBIN_EXPORT __global__ // enabled. if (!isMultiBlock && attentionSinks != nullptr) { // Attention sinks are per head. +#if SPEC_DEC + addAttentionSinksSpecDec(globalRowSum, globalRowMax, + attentionSinks + headGrpSize * idxHeadGrp, headGrpSize); +#else addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp); +#endif } ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum); #if LOW_PREC_OUTPUT @@ -2349,7 +2371,12 @@ CUBIN_EXPORT __global__ } if (attentionSinks != nullptr) { // Attention sinks are per head. +#if SPEC_DEC + addAttentionSinksSpecDec(mergedRowSum, mergedRowMax, + attentionSinks + headGrpSize * idxHeadGrp, headGrpSize); +#else addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp); +#endif } __syncthreads(); rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum)); diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 3f9d637b42..d5fbabc861 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -56,10 +56,8 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK Optional attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale, Optional kvScaleTensor, -#if SPEC_DEC - int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, -#endif - TensorView semaphores, TensorView scratch, bool enable_pdl) { + int64_t qSeqLen, Optional mask, TensorView semaphores, + TensorView scratch, bool enable_pdl) { auto stream = get_stream(output.device()); float const* attentionSinksPtr = attentionSinks.has_value() ? reinterpret_cast(attentionSinks.value().data_ptr()) @@ -70,13 +68,22 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK float const* kvScalePtr = kvScaleTensor.has_value() ? reinterpret_cast(kvScaleTensor.value().data_ptr()) : nullptr; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 auto const mha_func = run_sm90_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer; +#else + auto const mha_func = &launchMHAFlashInfer; +#endif // Extract strides from TensorView (in elements, not bytes) uint64_t kv_stride_page = kCacheVLLM.stride(0); uint64_t kv_stride_token = kCacheVLLM.stride(-3); uint64_t kv_stride_head = kCacheVLLM.stride(-2); +#if SPEC_DEC + MaskType const* maskPtr = + mask.has_value() ? 
reinterpret_cast(mask.value().data_ptr()) : nullptr; +#endif + mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, qScalePtr, reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT @@ -89,8 +96,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK reinterpret_cast(seqLen.data_ptr()), batchSize, kvCacheScale, kvScalePtr, #if SPEC_DEC - qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), - reinterpret_cast(mask.data_ptr()), + qSeqLen, nullptr, maskPtr, #endif reinterpret_cast(semaphores.data_ptr()), reinterpret_cast(scratch.data_ptr()), enable_pdl, kv_stride_page, kv_stride_token, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index ab34ba8857..1f682c9844 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2086,6 +2086,7 @@ def trtllm_batch_decode_with_kv_cache( backend: str = "auto", q_len_per_req: Optional[int] = 1, o_scale: Optional[float] = 1.0, + mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -2156,6 +2157,9 @@ def trtllm_batch_decode_with_kv_cache( o_scale : Optional[float] = 1.0 output scale factor for xqa fp8 output. + mask : Optional[torch.Tensor] = None + causal attention mask for xqa speculative decoding. + Returns ------- out : Union[torch.Tensor, FP4Tensor] @@ -2211,6 +2215,7 @@ def trtllm_batch_decode_with_kv_cache( enable_pdl=enable_pdl, q_len_per_req=q_len_per_req, o_scale=o_scale, + mask=mask, ) elif backend == "trtllm-gen": # Convert NHD layout to HND if necessary (transpose only changes stride, not data) @@ -2356,6 +2361,7 @@ def xqa_batch_decode_with_kv_cache( enable_pdl: bool = None, q_len_per_req: Optional[int] = 1, o_scale: Optional[float] = 1.0, + mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Parameters @@ -2407,6 +2413,9 @@ def xqa_batch_decode_with_kv_cache( o_scale : Optional[float] = 1.0 output scale factor for fp8 output. + mask : Optional[torch.Tensor] = None + causal attention mask for xqa speculative decoding. 
+ Returns ------- out : torch.Tensor @@ -2414,8 +2423,6 @@ def xqa_batch_decode_with_kv_cache( """ enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl - assert q_len_per_req == 1, "xqa not support speculative decoding yet" - if isinstance(kv_cache, tuple): k_cache, v_cache = kv_cache else: @@ -2449,6 +2456,9 @@ def xqa_batch_decode_with_kv_cache( kv_scale_value = bmm2_scale * o_scale q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5) + if q_len_per_req > 1: + batch_size = query.shape[0] // q_len_per_req + query = query.view(batch_size, q_len_per_req, query.shape[1], query.shape[2]) query_new = query.unsqueeze(1) seq_lens_new = seq_lens.unsqueeze(1) sinks_new = sinks.reshape(num_kv_heads, -1) if sinks is not None else None @@ -2477,6 +2487,8 @@ def xqa_batch_decode_with_kv_cache( sm_count=sm_count, enable_pdl=enable_pdl, rcp_out_scale=1.0 / o_scale, + q_seq_len=q_len_per_req, + mask=mask, ) return out diff --git a/flashinfer/jit/xqa.py b/flashinfer/jit/xqa.py index 04ab098be2..a35f3711cc 100644 --- a/flashinfer/jit/xqa.py +++ b/flashinfer/jit/xqa.py @@ -28,7 +28,6 @@ "-DBEAM_WIDTH=1", "-DUSE_INPUT_KV=0", "-DUSE_CUSTOM_BARRIER=1", - "-DSPEC_DEC=0", ] @@ -40,6 +39,7 @@ def gen_xqa_module( head_group_ratio: int, use_sliding_window: bool, output_dtype: torch.dtype, + q_seq_len: int = 1, ) -> JitSpec: if input_dtype == torch.float16: flag_input_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"] @@ -81,6 +81,16 @@ def gen_xqa_module( else: flag_low_prec_output = ["-DLOW_PREC_OUTPUT=0"] + if q_seq_len > 1: + use_spec_dec = True + if q_seq_len * head_group_ratio <= 32: + flag_spec_dec = ["-DSPEC_DEC=1", f"-DSPEC_Q_SEQ_LEN={q_seq_len}"] + else: + flag_spec_dec = ["-DSPEC_DEC=1"] + else: + flag_spec_dec = ["-DSPEC_DEC=0"] + use_spec_dec = False + compilation_context = CompilationContext() nvcc_flags = compilation_context.get_nvcc_flags_list( supported_major_versions=[9, 10, 11, 12] @@ -89,15 +99,22 @@ def gen_xqa_module( flag_mla_wrapper = ["-DMLA_WRAPPER=0"] + sources = [ + jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu", + jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu", + jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_binding.cu", + ] + + target_archs = compilation_context.TARGET_CUDA_ARCHS + + has_sm90 = any(major == 9 for major, minor in target_archs) + if has_sm90: + sources.append(jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu") + sources.append(jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp") + return gen_jit_spec( - f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}", - [ - jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu", - jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu", - jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp", - jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu", - jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_binding.cu", - ], + f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}", + sources, extra_cuda_cflags=xqa_nvcc_flags + sm_nvcc_flags + flag_tokens_per_page @@ -107,6 +124,7 @@ def gen_xqa_module( + flag_head_group_ratio + flag_sliding_window + flag_low_prec_output + + flag_spec_dec + 
flag_mla_wrapper, extra_ldflags=["-lcuda"], # Add CUDA Driver API library ) diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index bbc4832aac..88f96425d5 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -39,6 +39,7 @@ def get_xqa_module( head_group_ratio: int, use_sliding_window: bool, output_dtype: torch.dtype, + q_seq_len: int, ): module = gen_xqa_module( input_dtype, @@ -48,10 +49,16 @@ def get_xqa_module( head_group_ratio, use_sliding_window, output_dtype, + q_seq_len, ).build_and_load() + if q_seq_len > 1: + use_spec_dec = True + else: + use_spec_dec = False + @register_custom_op( - f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}", + f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}", mutates_args=("output", "workspace_buffer"), ) def xqa( @@ -74,6 +81,8 @@ def xqa( semaphores: torch.Tensor, workspace_buffer: torch.Tensor, enable_pdl: bool, + q_seq_len: int, + mask: Optional[torch.Tensor], ) -> None: module.xqa_wrapper( run_sm90_fp8_mha, @@ -94,13 +103,15 @@ def xqa( batch_size, 1.0 if isinstance(kv_scale, torch.Tensor) else kv_scale, None if isinstance(kv_scale, float) else kv_scale, + q_seq_len, + mask, semaphores, workspace_buffer, enable_pdl, ) @register_fake_op( - f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}" + f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}" ) def _fake_xqa( run_sm90_fp8_mha: bool, @@ -122,6 +133,8 @@ def _fake_xqa( semaphores: torch.Tensor, workspace_buffer: torch.Tensor, enable_pdl: bool, + q_seq_len: int, + mask: Optional[torch.Tensor], ) -> None: pass @@ -149,12 +162,15 @@ def xqa( sm_count: Optional[int] = None, enable_pdl: Optional[bool] = None, rcp_out_scale: float = 1.0, + q_seq_len: int = 1, + mask: Optional[torch.Tensor] = None, ) -> None: r"""Apply attention with paged KV cache using XQA kernel. Parameters ---------- q : torch.Tensor - Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``. + Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]`` if not using speculative decoding, + or ``[batch_size, beam_width, q_seq_len, num_q_heads, head_dim]`` if using speculative decoding. ``q_seq_len`` is the number of speculative decoding tokens. Data type should be torch.float16 or torch.bfloat16. Now only beam_width 1 is supported. k_cache: torch.Tensor @@ -175,7 +191,7 @@ def xqa( Sequence lengths tensor with shape ``[batch_size, beam_width]``. Data type should be torch.uint32. output : torch.Tensor - Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``. + Output tensor with shape that matches the query tensor. 
Data type should match query tensor or kv tensor. This tensor will be modified in-place. workspace_buffer : torch.Tensor Workspace buffer for temporary computations. @@ -207,12 +223,19 @@ def xqa( If None, will be set to True if hardware supports it. rcp_out_scale : float, default=1.0 Reciprocal of output scale factor. + q_seq_len : int, default=1 + Query sequence length. When > 1, enables speculative decoding mode. + mask : Optional[torch.Tensor], default=None + Causal attention mask for speculative decoding mode (when ``q_seq_len > 1``). + Shape: ``[batch_size, q_seq_len, mask_size_per_row]`` where + ``mask_size_per_row = ((q_seq_len + 31) // 32) * 2``. + Data type should be torch.uint16 (bit-packed format, aligned to 32 bits). Note ---- The function automatically infers several parameters from tensor shapes: - batch_size from q.shape[0] - - num_q_heads from q.shape[2] + - num_q_heads from q.shape[-2] - head_dim from q.shape[-1] - input_dtype from q.dtype - kv_cache_dtype from k.dtype @@ -227,7 +250,7 @@ def xqa( # Infer parameters from tensors batch_size = q.shape[0] - num_q_heads = q.shape[2] + num_q_heads = q.shape[-2] head_dim = q.shape[-1] # Calculate head_group_ratio @@ -274,7 +297,15 @@ def xqa( head_group_ratio, use_sliding_window, output.dtype, + q_seq_len, ) + + if q_seq_len > 1: + assert mask is not None, "Mask is required for speculative decoding" + run_sm90_fp8_mha = ( + False # TODO: mha_sm90.cu has precision issue with speculative decoding + ) + xqa_module.xqa( run_sm90_fp8_mha, sm_count, @@ -295,6 +326,8 @@ def xqa( semaphores, workspace_buffer, enable_pdl, + q_seq_len, + mask, ) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 4e2c615827..4e2b7aefe0 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -339,6 +339,66 @@ def unpack_compare_nvfp4( return output_unpacked, output_ref +def generate_causal_mask( + batch_size: int, + q_seq_len: int, + device: torch.device, +) -> torch.Tensor: + """ + Generate causal attention mask for speculative decoding. + + Parameters + ---------- + batch_size : int + Batch size + q_seq_len : int + Query sequence length (number of speculative decoding tokens) + device : torch.device + Target device for the mask tensor + + Returns + ------- + torch.Tensor + Causal mask with shape [batch_size, q_seq_len, mask_size_per_row] + where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units). 
+ Data type: torch.uint16 + + """ + num_packed_masks_per_token = (q_seq_len + 31) // 32 + + q_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(1) + kv_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(0) + + causal_bool_mask = kv_indices <= q_indices + + padded_seq_len = num_packed_masks_per_token * 32 + if padded_seq_len > q_seq_len: + padding = torch.zeros( + q_seq_len, padded_seq_len - q_seq_len, device=device, dtype=torch.bool + ) + causal_bool_mask = torch.cat([causal_bool_mask, padding], dim=1) + + causal_bool_mask = causal_bool_mask.view(q_seq_len, num_packed_masks_per_token, 32) + + bit_positions = torch.tensor( + [1 << i for i in range(32)], device=device, dtype=torch.int64 + ) + + mask_uint32 = ( + (causal_bool_mask.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32) + ) + + mask_uint32 = ( + mask_uint32.unsqueeze(0) + .expand(batch_size, q_seq_len, num_packed_masks_per_token) + .contiguous() + ) + + mask_uint16 = mask_uint32.view(torch.uint16) + + return mask_uint16 + + def _test_trtllm_batch_prefill( kv_layout, batch_size, @@ -701,12 +761,6 @@ def _test_trtllm_batch_decode( if backend == "xqa" and q_dtype == "fp8": pytest.skip("xqa backend only supports fp16 and bf16 query") - # xqa backend doesn't support speculative decoding yet - if backend == "xqa" and q_len_per_req > 1: - pytest.skip( - "xqa backend does not support speculative decoding (q_len_per_req > 1) yet" - ) - if o_dtype == "nvfp4" and q_len_per_req > 1: # todo(Yingyi): add support for nvfp4 with speculative decoding pytest.skip("nvfp4 is not supported for q_len_per_req > 1") @@ -826,6 +880,11 @@ def _test_trtllm_batch_decode( kv_indptr=kv_indptr_tokens, ) + if q_len_per_req > 1: + mask = generate_causal_mask(batch_size, q_len_per_req, GPU_DEVICE) + else: + mask = None + # Run decode function call with specified backend bmm1_scale = q_scale * k_scale * sm_scale bmm2_scale = v_scale / o_scale @@ -857,6 +916,7 @@ def _test_trtllm_batch_decode( backend=backend, q_len_per_req=q_len_per_req, o_scale=o_scale, + mask=mask, ) if backend == "trtllm-gen": # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero diff --git a/tests/attention/test_xqa_batch_decode.py b/tests/attention/test_xqa_batch_decode.py index 542e3194bf..6a06575ad6 100644 --- a/tests/attention/test_xqa_batch_decode.py +++ b/tests/attention/test_xqa_batch_decode.py @@ -290,6 +290,66 @@ def get_last_page_len(seq_lens, page_size): return last_page_len +def generate_causal_mask( + batch_size: int, + q_seq_len: int, + device: torch.device, +) -> torch.Tensor: + """ + Generate causal attention mask for speculative decoding. + + Parameters + ---------- + batch_size : int + Batch size + q_seq_len : int + Query sequence length (number of speculative decoding tokens) + device : torch.device + Target device for the mask tensor + + Returns + ------- + torch.Tensor + Causal mask with shape [batch_size, q_seq_len, mask_size_per_row] + where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units). 
+ Data type: torch.uint16 + + """ + num_packed_masks_per_token = (q_seq_len + 31) // 32 + + q_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(1) + kv_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(0) + + causal_bool_mask = kv_indices <= q_indices + + padded_seq_len = num_packed_masks_per_token * 32 + if padded_seq_len > q_seq_len: + padding = torch.zeros( + q_seq_len, padded_seq_len - q_seq_len, device=device, dtype=torch.bool + ) + causal_bool_mask = torch.cat([causal_bool_mask, padding], dim=1) + + causal_bool_mask = causal_bool_mask.view(q_seq_len, num_packed_masks_per_token, 32) + + bit_positions = torch.tensor( + [1 << i for i in range(32)], device=device, dtype=torch.int64 + ) + + mask_uint32 = ( + (causal_bool_mask.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32) + ) + + mask_uint32 = ( + mask_uint32.unsqueeze(0) + .expand(batch_size, q_seq_len, num_packed_masks_per_token) + .contiguous() + ) + + mask_uint16 = mask_uint32.view(torch.uint16) + + return mask_uint16 + + @pytest.mark.skipif( get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12], reason="XQA is only supported on SM90, SM100, SM120 GPUs", @@ -297,6 +357,9 @@ def get_last_page_len(seq_lens, page_size): @pytest.mark.parametrize( "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", [ + (4, 4, 64, 4, 2), + (4, 2, 16, 2, 4), + (4, 3, 32, 2, 6), (4, 1, 16, 2, 1), (4, 1, 32, 2, 5), (128, 1, 64, 2, 6), @@ -338,8 +401,6 @@ def test_xqa_batch_decode( This test supports both NHD and HND layouts. """ - if q_len_per_req > 1: - pytest.skip("xqa does not support speculative decoding yet") # Set up test parameters torch.manual_seed(0) @@ -444,6 +505,11 @@ def test_xqa_batch_decode( kv_indptr=kv_indptr_tokens, ) + if q_len_per_req > 1: + mask = generate_causal_mask(batch_size, q_len_per_req, GPU_DEVICE) + else: + mask = None + # Run xqa_batch_decode_with_kv_cache function output = flashinfer.decode.xqa_batch_decode_with_kv_cache( q.contiguous(), @@ -461,6 +527,7 @@ def test_xqa_batch_decode( kv_layout=kv_layout, q_len_per_req=q_len_per_req, o_scale=o_scale, + mask=mask, ) # Verification From d56be0dc6b16668548b9b2d112db7d61f0763ac9 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Sat, 22 Nov 2025 01:55:43 -0600 Subject: [PATCH 083/130] Add custom communicator for trtllm_mnnvl_ar (#2056) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Added optional communication-backend parameter for multi-node memory and buffer allocation to allow using a provided communicator for handle transfer. 
* **Bug Fixes / Reliability** * Multi-node synchronization now uses the provided communicator's barrier when available, preserving previous behavior otherwise. * **Tests** * Added end-to-end tests covering custom communication backends and multi-node all-reduce synchronization. --- flashinfer/comm/mnnvl.py | 30 +- flashinfer/comm/trtllm_mnnvl_ar.py | 14 +- tests/comm/test_trtllm_mnnvl_allreduce.py | 11 +- ...test_trtllm_mnnvl_allreduce_custom_comm.py | 263 ++++++++++++++++++ 4 files changed, 305 insertions(+), 13 deletions(-) create mode 100644 tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 12aec978ec..2d280a68e8 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -155,6 +155,9 @@ def Get_size(self) -> int: ... @abstractmethod def allgather(self, data: int) -> List[int]: ... + @abstractmethod + def barrier(self) -> None: ... + @abstractmethod def Split(self, color: int, key: int) -> "CommBackend": ... @@ -209,6 +212,9 @@ def Get_size(self) -> int: def allgather(self, data: int) -> List[int]: return self._mpicomm.allgather(data) + def barrier(self): + self._mpicomm.Barrier() + def Split(self, color: int, key: int) -> CommBackend: self._mpicomm = self._mpicomm.Split(color, key) return MPIBackend() # Returns new adapter @@ -555,6 +561,7 @@ def __init__( group_rank: int, device_idx: int, is_multi_node: bool = True, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -631,7 +638,7 @@ def __init__( "[McastDeviceMemory] Device does not support fabric handle." ) - self._alloc_mn_mcast_mem(buf_size) + self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer) else: # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem raise NotImplementedError("Single-node NVLS allocation not implemented yet") @@ -753,7 +760,9 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size - def _alloc_mn_mcast_mem(self, buf_size: int): + def _alloc_mn_mcast_mem( + self, buf_size: int, comm_backend_for_handle_transfer: Any = None + ): """Allocate multi-node multicast memory using MNNVL""" # Verify CUDA context @@ -766,10 +775,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) except Exception as e: print(f"Error checking CUDA context: {e}") - - # Get MPI communicator - comm = MpiComm() - + if comm_backend_for_handle_transfer is None: + comm = MpiComm() + else: + comm = comm_backend_for_handle_transfer # Set up allocation properties handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC @@ -969,6 +978,7 @@ def __init__( group_rank: int, device: torch.device, mn_nvlink: bool = True, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): """ Constructor for McastGpuBuffer. 
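To make the intent of the new argument concrete, here is a minimal usage sketch (illustrative only, not part of the diff): any object exposing the small `CommBackend` surface exercised for handle transfer — `Get_rank`, `Get_size`, `allgather`, `barrier`, `Split` — can be handed to `get_allreduce_mnnvl_workspace`, so MPI is no longer required. The `TorchDistBackend` name below is hypothetical; the tests in this patch add an equivalent `CustomCommunicator`.

```python
import torch
import torch.distributed as dist

import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
from flashinfer.comm.mapping import Mapping


class TorchDistBackend:
    """Hypothetical adapter: only the methods used for fabric-handle transfer."""

    def __init__(self, group):
        self._group = group

    def Get_rank(self) -> int:
        return dist.get_rank(self._group)

    def Get_size(self) -> int:
        return dist.get_world_size(self._group)

    def allgather(self, data):
        gathered = [None] * self.Get_size()
        dist.all_gather_object(gathered, data, group=self._group)
        return gathered

    def barrier(self) -> None:
        dist.barrier(group=self._group)

    def Split(self, color: int, key: int) -> "TorchDistBackend":
        return self  # single TP group here, no sub-communicators needed


# Assumes dist.init_process_group(...) and torch.cuda.set_device(rank) already ran.
# On a single node, TRTLLM_FORCE_MNNVL_AR=1 forces the MNNVL path (as the new test does).
comm = TorchDistBackend(dist.group.WORLD)
mapping = Mapping(
    world_size=comm.Get_size(),
    rank=comm.Get_rank(),
    gpus_per_node=comm.Get_size(),
    tp_size=comm.Get_size(),
)
mcast_buf, buffer_flags, max_elems = trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
    mapping, torch.bfloat16, comm_backend_for_handle_transfer=comm
)
```

With this in place, handle exchange and the post-initialization barrier go through the provided backend; passing `None` keeps the previous MPI behavior.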
@@ -979,9 +989,15 @@ def __init__( group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation mn_nvlink: Flag indicating if multi-node NVLink is used + comm_backend_for_handle_transfer: Communication backend for handle transfer """ self.mcast_device_memory = McastDeviceMemory( - buf_size, group_size, group_rank, device.index, mn_nvlink + buf_size, + group_size, + group_rank, + device.index, + mn_nvlink, + comm_backend_for_handle_transfer, ) self.buf_size = buf_size self.local_device = device diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 76aedee260..84a9c150de 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -15,7 +15,7 @@ from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op -from .mnnvl import McastGPUBuffer +from .mnnvl import McastGPUBuffer, CommBackend def mpi_barrier(): @@ -122,7 +122,10 @@ def trtllm_mnnvl_rmsnorm( def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None + mapping: Mapping, + dtype: torch.dtype, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, + buffer_size_in_bytes: Optional[int] = None, ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. @@ -138,6 +141,7 @@ def get_allreduce_mnnvl_workspace( Args: mapping: Tensor parallel mapping configuration containing rank info dtype: Data type of the tensors being reduced + comm: Optional communication backend for multi-node synchronization buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens Returns: @@ -167,6 +171,7 @@ def get_allreduce_mnnvl_workspace( mapping.tp_rank, torch.device("cuda", mapping.local_rank), mapping.is_multi_node() or force_mn, + comm_backend_for_handle_transfer=comm_backend_for_handle_transfer, ) # Initialize the unicast buffer with -0.0 @@ -174,7 +179,10 @@ def get_allreduce_mnnvl_workspace( # CPU barrier since we assume this should not be called in cuda graph torch.cuda.synchronize() - mpi_barrier() + if comm_backend_for_handle_transfer is None: + mpi_barrier() + else: + comm_backend_for_handle_transfer.barrier() # This is a buffer to maintain the state of this allreduce Op # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index abb3795019..e7274c46f0 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -1,5 +1,5 @@ # Check torch version: -from typing import Tuple +from typing import Tuple, Optional import pytest import torch @@ -7,6 +7,7 @@ import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping +from flashinfer.comm.mnnvl import CommBackend, MpiComm # Use flashinfer.norm.rmsnorm as reference implementation. 
from flashinfer.norm import rmsnorm @@ -28,6 +29,7 @@ def row_linear_residual_norm_fusion_forward( unicast_ptr: int, max_num_elements_mnnvl: int, buffer_flags_mnnvl: torch.Tensor, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): x = x.cuda() residual = residual.cuda() @@ -36,8 +38,11 @@ def row_linear_residual_norm_fusion_forward( tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank - - MPI.COMM_WORLD.barrier() + if comm_backend_for_handle_transfer is None: + comm = MpiComm() + else: + comm = comm_backend_for_handle_transfer + comm.barrier() def func( input, diff --git a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py new file mode 100644 index 0000000000..60933cf89b --- /dev/null +++ b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py @@ -0,0 +1,263 @@ +# Check torch version: +from typing import Any, Tuple + +import multiprocessing as mp +import socket +import pytest +import torch +import torch.distributed as dist + +import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar +from flashinfer.comm.mapping import Mapping +from flashinfer.comm.mnnvl import CommBackend as CommBackend + +import pynvml + +pynvml.nvmlInit() + + +class CustomCommunicator(CommBackend): + def __init__(self, group): + self._group = group + + def Get_rank(self) -> int: + return dist.get_rank(self._group) + + def Get_size(self) -> int: + return dist.get_world_size(self._group) + + def allgather(self, data: int | bytes): + device = f"cuda:{torch.cuda.current_device()}" + if isinstance(data, int): + local_tensor = torch.tensor([data], device=device, dtype=torch.int32) + world_size = self.Get_size() + gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)] + + dist.all_gather(gathered, local_tensor, group=self._group) + return [int(x.item()) for x in gathered] + + elif isinstance(data, bytes): + local_tensor = torch.ByteTensor(list(data)).unsqueeze(0).to(device) + world_size = self.Get_size() + gathered = [data] * self.Get_size() + dist.all_gather_object(gathered, data, group=self._group) + return gathered + else: + raise TypeError(f"Unsupported type for allgather: {type(data)}") + + def bcast(self, data, root: int = 0): + """ + Broadcast a picklable Python object from `root` to all ranks. + Uses torch.distributed.broadcast_object_list under the hood. + + Returns the broadcasted object on every rank. + """ + obj_list = [data] + # broadcast_object_list mutates obj_list in-place + dist.broadcast_object_list(obj_list, src=root, group=self._group) + return obj_list[0] + + def barrier(self): + """ + Synchronize all ranks in this communicator. 
+ """ + dist.barrier(group=self._group) + + def Split(self, color: int, key: int) -> "CustomCommunicator": + return self + + +def get_open_port() -> int: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + except OSError: + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("::1", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = () +) -> None: + mp.set_start_method("spawn", force=True) + + procs = [] + distributed_init_port = get_open_port() + for i in range(world_size): + proc_args = (world_size, i, dtype, distributed_init_port) + target_args + proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") + proc.start() + procs.append(proc) + + for i in range(world_size): + procs[i].join() + assert procs[i].exitcode == 0, ( + f"Process {i} failed with exit code {procs[i].exitcode}" + ) + + +def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidden_size): + # Set CUDA device based on rank + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + group = dist.group.WORLD + torch.cuda.set_device(rank) + comm = CustomCommunicator(group) + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=world_size, + tp_size=world_size, + ) + + if mapping.local_rank == 0: + print( + f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" + ) + print( + f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" + ) + + tensor_parallel_size = world_size + eps = 1e-5 + torch.manual_seed(42) + + # Track if this rank failed + rank_failed = False + failure_message = "" + + try: + # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list + # This workspace is sized for the maximum expected sequence length and can be reused within each list + # Each parameterized list gets its own fresh workspace allocation + explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * seq_len + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( + mapping, dtype, comm, explicit_workspace_bytes + ) + ) + + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() + unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( + mapping.tp_rank + ) + + # Test each sequence length with the same workspace (reusing allocated buffers within this list) + if rank == 0: + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, dtype={dtype}" + ) + + # Generate test data (same on all ranks due to same seed) + x_full = torch.randn( + (tensor_parallel_size, seq_len, hidden_size), + dtype=dtype, + device=torch.device("cuda"), + ) + residual = torch.randn( + (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") + ) + norm_weight = torch.randn( + (hidden_size,), dtype=dtype, device=torch.device("cuda") + ) + + # Each rank gets its slice of the input + x = x_full[rank, :, :] + + # Compute reference output based on fusion mode + reference_output: Tuple[torch.Tensor, ...] 
= None + + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + + # Run the test with the same workspace + from .test_trtllm_mnnvl_allreduce import row_linear_residual_norm_fusion_forward + + row_linear_residual_norm_fusion_forward( + x, + residual, + norm_weight, + eps, + hidden_size, + dtype, + mapping, + False, + reference_output, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + buffer_flags_mnnvl, + comm, + ) + + # Synchronize before next test + comm.barrier() + + print(f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}") + + except Exception as e: + rank_failed = True + failure_message = ( + f"FAILED[rank={rank}]: seq_lens={seq_len}, dtype={dtype} failed: {e}" + ) + print(failure_message) + # Gather failure status from all ranks + all_failures = comm.allgather(rank_failed) + + # If any rank failed, fail the test + if any(all_failures): + failed_ranks = [i for i, failed in enumerate(all_failures) if failed] + if rank == 0: + print(f"Test failed on ranks: {failed_ranks}") + + # Fail the test on all ranks + pytest.fail(f"Test failed on ranks {failed_ranks}") + comm.barrier() + + finally: + # Ensure cleanup happens for this list's workspace + if "mcast_buffer_mnnvl" in locals(): + del mcast_buffer_mnnvl + + # Final synchronization and check for failures across all ranks + comm.barrier() + + +"""Main test function that runs on each MPI rank""" + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_mnnvl_allreduce_custom_communicator( + monkeypatch, + world_size, +): + monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. + seq_len = 24 + dtype = torch.bfloat16 + hidden_size = 2048 + + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + raise ValueError( + f"world_size {world_size} is greater than available_gpus {available_gpus}" + ) + print(f"Running test for world_size={world_size}") + multi_process_parallel( + world_size, + dtype, + _run_mnnvl_ar, + target_args=(seq_len, hidden_size), + ) + print(f"custom mnnvl allreduce world_size = {world_size}: OK") From cf2df82ae0af179d21525769d55ca50a9a6525a0 Mon Sep 17 00:00:00 2001 From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com> Date: Sat, 22 Nov 2025 17:37:16 +0100 Subject: [PATCH 084/130] fix: DeepSeek activation uninitialized data (#2128) --- csrc/trtllm_fused_moe_dev_kernel.cu | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu index 7a58042041..a19c89638d 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/trtllm_fused_moe_dev_kernel.cu @@ -196,6 +196,8 @@ struct KernelTraits<1> { //////////////////////////////////////////////////////////////////////////////////////////////////// +constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128; + template __global__ void activationDeepSeekKernel(KernelParams params) { using Type = typename KernelParams::Type; @@ -203,7 +205,7 @@ __global__ void activationDeepSeekKernel(KernelParams params) { using KernelTraits = KernelTraits; using MaxOp = typename KernelTraits::MaxOp; using PackedType = typename KernelTraits::PackedType; - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ float s_scaleOutArr[NumTokensPerCta]; __shared__ typename BlockReduce::TempStorage tempStorage; @@ -235,6 +237,15 @@ __global__ void activationDeepSeekKernel(KernelParams 
params) { tokenCtaIdx += gridDim.y * NumTokensPerCta) { for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2; hiddenIdx += blockDim.x * gridDim.x) { +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + scale1Arr[tokenInCtaIdx] = 0.0f; + scale2Arr[tokenInCtaIdx] = 0.0f; + dataX1Arr[tokenInCtaIdx] = 0.0f; + dataX2Arr[tokenInCtaIdx] = 0.0f; + outArr[tokenInCtaIdx] = 0.0f; + absOutArr[tokenInCtaIdx] = 0.0f; + } #pragma unroll for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { int const tokenIdx = tokenCtaIdx + tokenInCtaIdx; @@ -328,7 +339,6 @@ void run(Data const& data, void* stream) { if (data.mUseDeepSeekFp8) { constexpr int NUM_ELTS_PER_LOAD = 1; constexpr int NUM_ELTS_PER_SF = 128; - int const NUM_THREADS_PER_CTA = 128; int device{-1}; cudaGetDevice(&device); @@ -355,8 +365,8 @@ void run(Data const& data, void* stream) { const dim3 grid(gridSizeX, gridSizeY, data.topK); - LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid, NUM_THREADS_PER_CTA, 0, - stream); + LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid, + DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0, stream); } else { int const numThreads = 256; const dim3 grid(data.innerDim / 128, data.topK, data.numTokens); From 9f13e83aabc5b3abf1427e262ea9d39bec08ff33 Mon Sep 17 00:00:00 2001 From: FlashInfer Bot Date: Sun, 23 Nov 2025 23:34:12 -0800 Subject: [PATCH 085/130] chore: Update CODEOWNERS (#2135) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR updates the CODEOWNERS file based on git commit history analysis from the last 180 days. ## Changes - Updated `.github/CODEOWNERS` with current code ownership based on: - Commit frequency - File coverage - Commit recency ## How to Review 1. Review the changes to `.github/CODEOWNERS` 2. Verify that the assigned owners are appropriate for each module 3. Make manual adjustments if needed before merging ## Notes - This is an automated PR generated weekly - Minimum commits threshold: 1 - Analysis period: 180 days - Directory depth: 3 levels - Top N owners per module: 5 --- ๐Ÿค– This PR was automatically generated by the [update-codeowners workflow](.github/workflows/update-codeowners.yml) ## Summary by CodeRabbit * **Chores** * Updated code ownership and review assignments across project directories to optimize approval workflows and access control management. โœ๏ธ Tip: You can customize this high-level summary in your review settings. 
Co-authored-by: flashinfer-bot Co-authored-by: Claude --- .github/CODEOWNERS | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index fc3b20c491..295ec2a27c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,43 +3,43 @@ # Analysis period: 180 days # Minimum commits threshold: 1 -benchmarks/ @bkryu @jiahanc @cyx-6 @yzh119 @nv-yunzheq +benchmarks/ @bkryu @jiahanc @cyx-6 @kahyunnam @yzh119 benchmarks/routines/ @bkryu @nv-yunzheq @jiahanc @cyx-6 @nvmbreughe ci/ @cyx-6 @yzh119 @nvmbreughe ci/scripts/ @cyx-6 ci/scripts/jenkins/ @cyx-6 csrc/ @wenscarl @yzh119 @cyx-6 @djmmoss @nv-yunzheq -csrc/fused_moe/ @nv-yunzheq @yzh119 @yongwww @djmmoss @cyx-6 +csrc/fused_moe/ @nv-yunzheq @yzh119 @yongwww @cyx-6 @djmmoss csrc/fused_moe/cutlass_backend/ @nv-yunzheq @yzh119 @yongwww @djmmoss @cyx-6 csrc/nv_internal/ @wenscarl @djmmoss @nv-yunzheq @yongwww @cyx-6 csrc/nv_internal/cpp/ @wenscarl @bkryu @yongwww @djmmoss @joker-eph csrc/nv_internal/include/ @wenscarl @nv-yunzheq csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @nv-yunzheq @yongwww @cyx-6 csrc/xqa/ @cyx-6 @yzh119 -docs/ @yzh119 @cyx-6 @wenscarl @nv-yunzheq @aleozlx -flashinfer/ @yzh119 @cyx-6 @nvmbreughe @aleozlx @wenscarl +docs/ @yzh119 @cyx-6 @bkryu @wenscarl @nv-yunzheq +flashinfer/ @yzh119 @cyx-6 @wenscarl @nvmbreughe @aleozlx flashinfer-cubin/ @yzh119 @cyx-6 flashinfer-cubin/flashinfer_cubin/ @yzh119 flashinfer-jit-cache/ @yzh119 @cyx-6 flashinfer-jit-cache/flashinfer_jit_cache/ @yzh119 flashinfer/comm/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @djmmoss -flashinfer/cudnn/ @Anerudhan @yzh119 @cyx-6 @Anerudhan +flashinfer/cudnn/ @Anerudhan @yzh119 @bkryu @cyx-6 @Anerudhan flashinfer/cute_dsl/ @yzh119 @kaixih @Amir-19 @aleozlx -flashinfer/dsv3_ops/ @nvmbreughe -flashinfer/fused_moe/ @djmmoss @jiahanc @yzh119 @cyx-6 @aleozlx -flashinfer/gemm/ @nvmbreughe -flashinfer/jit/ @yzh119 @cyx-6 @aleozlx @jiahanc @nvmbreughe -flashinfer/jit/attention/ @yzh119 @cyx-6 @Anerudhan @joker-eph +flashinfer/dsv3_ops/ @nv-yunzheq @nvmbreughe +flashinfer/fused_moe/ @nv-yunzheq @jiahanc @djmmoss @yzh119 @cyx-6 +flashinfer/gemm/ @nvmbreughe @bkryu +flashinfer/jit/ @yzh119 @cyx-6 @aleozlx @nv-yunzheq @jiahanc +flashinfer/jit/attention/ @yzh119 @cyx-6 @Anerudhan flashinfer/jit/gemm/ @yzh119 @nv-yunzheq @jiahanc flashinfer/logits_processor/ @cyx-6 @yzh119 flashinfer/profiler/ @cyx-6 flashinfer/triton/ @nvmbreughe @cyx-6 flashinfer/tuning_configs/ @kaixih -include/ @yzh119 @jiahanc @nvmbreughe @IwakuraRein @bkryu -include/flashinfer/ @yzh119 @jiahanc @nvmbreughe @IwakuraRein @bkryu +include/ @yzh119 @kahyunnam @jiahanc @IwakuraRein @nv-yunzheq +include/flashinfer/ @yzh119 @kahyunnam @jiahanc @IwakuraRein @nv-yunzheq include/flashinfer/attention/ @yzh119 @kahyunnam @joker-eph include/flashinfer/comm/ @yongwww @nvmbreughe @djmmoss @yzh119 @cyx-6 include/flashinfer/gemm/ @ttyio @yongwww @yzh119 @nvmbreughe @aleozlx -include/flashinfer/trtllm/ @jiahanc @joker-eph @aleozlx @yzh119 @wenscarl +include/flashinfer/trtllm/ @jiahanc @joker-eph @aleozlx @yzh119 @IwakuraRein profiler/ @cyx-6 scripts/ @yzh119 @nvmbreughe @dierksen @yongwww @bkryu From ecd4ef176f729a9ed809ff1aeea57a4e1f2564ab Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 24 Nov 2025 01:09:53 -0800 Subject: [PATCH 086/130] bugfix: fix unittest error introduced in #2056 (#2136) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description In #2056 , when `world_size > 
available_gpus`, we should skip the UT instead of raising an error. cc @wenscarl for viz.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer!
Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Tests**
  * Enhanced test execution logic to improve reliability when handling resource constraints.

✍️ Tip: You can customize this high-level summary in your review settings.
---
 tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py
index 60933cf89b..772ceead0b 100644
--- a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py
+++ b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py
@@ -250,7 +250,7 @@ def test_mnnvl_allreduce_custom_communicator(
 
     available_gpus = torch.cuda.device_count()
     if world_size > available_gpus:
-        raise ValueError(
+        pytest.skip(
             f"world_size {world_size} is greater than available_gpus {available_gpus}"
         )
     print(f"Running test for world_size={world_size}")

From efd8554911efcc072e8937826a02231d83ccf62d Mon Sep 17 00:00:00 2001
From: qsang-nv <200703406+qsang-nv@users.noreply.github.com>
Date: Tue, 25 Nov 2025 09:03:11 +0800
Subject: [PATCH 087/130] fix flaky xqa test (#2126)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## 📌 Description

WIP. Do not merge, see if this could fix xqa flaky test.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer!
Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Tests**
  * Default test seed changed to improve reproducibility; tests now use batched K/V handling, batched reference comparisons, expanded sequence-length cases, device-based scaling tensors, seeded shuffling, and batch-level validation with adjusted tolerances.
  * Over-provisioned GPU runs now skip instead of failing.
* **Bug Fixes**
  * More consistent attention scaling and more robust GPU attention validation across batched and device-based test paths.

✍️ Tip: You can customize this high-level summary in your review settings.
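For reviewers, the core of the de-flaking change is that the Python reference is now computed once per batch in fp32 rather than per request and head in nested Python loops. A condensed sketch of that batched reference (illustrative only — `batched_ref_attention` is not a name in this patch, and the `x_scale` plumbing of the real test is folded into the constants here):

```python
import math
import torch


def batched_ref_attention(q, k, v, seq_len, q_scale=1.0, kv_scale=1.0,
                          sliding_win=0, sinks=None):
    # q: [B, H_kv, G, D]; k, v: [B, H_kv, S, D] with S >= seq_len. All math in fp32.
    qk_scale = q_scale * kv_scale / math.sqrt(q.shape[-1])
    scores = torch.matmul(q.float(), k[:, :, :seq_len].float().transpose(-2, -1)) * qk_scale
    if sliding_win and seq_len > sliding_win:
        scores[..., : seq_len - sliding_win] = float("-inf")  # tokens outside the window
    row_max = scores.amax(dim=-1, keepdim=True)               # [B, H_kv, G, 1]
    probs = torch.exp(scores - row_max)
    denom = probs.sum(dim=-1, keepdim=True)
    if sinks is not None:                                     # sinks: [H_kv, G], shared across the batch
        denom = denom + torch.exp(sinks[None, :, :, None] - row_max)
    return torch.matmul(probs, v[:, :, :seq_len].float()) * kv_scale / denom
```

The kernel output is then compared against this batched reference in one shot, while the seeded `torch.Generator` keeps the page shuffle reproducible across runs.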
--------- Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> Co-authored-by: Zihao Ye --- csrc/xqa/mha.cu | 4 +- csrc/xqa/mha_sm90.cu | 4 +- csrc/xqa/mla_sm120.cu | 4 +- tests/attention/test_xqa.py | 435 +++++++++++++++++++++--------------- 4 files changed, 259 insertions(+), 188 deletions(-) diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index 90576367b3..c8d6ca2c22 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -1327,8 +1327,8 @@ CUBIN_EXPORT __global__ uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head, uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { - float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale; - float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale; + float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale; + float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale; assert(allowMultiBlockMode || gridDim.x == 1); bool const isMultiBlock = allowMultiBlockMode && (gridDim.x != 1); uint32_t const nbSubSeqPerSeq = allowMultiBlockMode ? gridDim.x : 1; diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu index 06938edd91..a39b94cc21 100644 --- a/csrc/xqa/mha_sm90.cu +++ b/csrc/xqa/mha_sm90.cu @@ -640,8 +640,8 @@ __launch_bounds__(128 * 3) uint32_t* __restrict__ const semaphores = nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)] void* __restrict__ const scratch = nullptr) { - float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale; - float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale; + float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale; + float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) && \ (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1 uint32_t const idxReq = blockIdx.z / nbKHeads; diff --git a/csrc/xqa/mla_sm120.cu b/csrc/xqa/mla_sm120.cu index 495d9e94d0..fc77535bfd 100644 --- a/csrc/xqa/mla_sm120.cu +++ b/csrc/xqa/mla_sm120.cu @@ -1564,8 +1564,8 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha PartialResult* __restrict__ const partialResults = nullptr) // [totalNbInputTokens][maxNbSubSeq] { - float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale; - float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale; + float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale; + float const kvCacheScaleValue = kvScalePtr != nullptr ? 
kvScalePtr[0] : kvCacheScale; assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1); extern __shared__ char smemBuf[]; uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size); diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index b6454de05a..6884998fc8 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -8,7 +8,7 @@ from flashinfer.utils import get_compute_capability -def set_random_seed(seed=42): +def set_random_seed(seed=0): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -33,8 +33,8 @@ def div_up(a, b): def ref_attention( q, - k_cache, # Changed: now takes full tensor [seq_len, dim] - v_cache, # Changed: now takes full tensor [seq_len, dim] + k_cache, + v_cache, seq_len, q_scale, kv_scale, @@ -42,16 +42,22 @@ def ref_attention( attention_sinks, sliding_win_size, valid_elems_per_head, - valid_elems_per_v_head=None, # Optional: for MLA where V dim != K dim + valid_elems_per_v_head=None, ): """ - For MLA: - - Q/K dimension: 576 (valid_elems_per_head) - - V dimension: 512 (valid_elems_per_v_head) - - Output dimension: matches valid_elems_per_head (576) but only first - valid_elems_per_v_head (512) elements are valid + Batched reference attention implementation. + + Args: + q: [batch_size, nb_k_heads, head_grp_size, valid_elems_per_head] + k_cache: [batch_size, nb_k_heads, seq_len, valid_elems_per_head] + v_cache: [batch_size, nb_k_heads, seq_len, valid_elems_per_v_head] + seq_len: scalar or [batch_size] tensor + attention_sinks: [nb_k_heads, head_grp_size] or None + + Returns: + out: [batch_size, nb_k_heads, head_grp_size, valid_elems_per_v_head] """ - head_grp_size = q.shape[0] + batch_size, nb_k_heads, head_grp_size, _ = q.shape rcp_x_scale = 1.0 / x_scale qk_scale = q_scale * kv_scale / math.sqrt(valid_elems_per_head) @@ -59,21 +65,16 @@ def ref_attention( if valid_elems_per_v_head is None: valid_elems_per_v_head = valid_elems_per_head - q_f32 = q.to(torch.float32) # [head_grp_size, valid_elems_per_head] - - # Directly use the pre-assembled cache tensors - k_cache_f32 = k_cache[:seq_len].to(torch.float32) # [seq_len, valid_elems_per_head] - # For MLA: V cache storage is 576 but only first 512 elements are valid - v_cache_f32 = v_cache[:seq_len, :valid_elems_per_v_head].to( + # Convert to float32 for computation + q_f32 = q.to( torch.float32 - ) # [seq_len, valid_elems_per_v_head] - - # q_f32: [head_grp_size, valid_elems_per_head] - # k_cache_f32: [seq_len, valid_elems_per_head] - # gemm0_acc: [head_grp_size, seq_len] - gemm0_acc = torch.zeros( - head_grp_size, seq_len, dtype=torch.float32, device=q_f32.device - ) + ) # [batch_size, nb_k_heads, head_grp_size, valid_elems_per_head] + k_cache_f32 = k_cache[:, :, :seq_len].to( + torch.float32 + ) # [batch_size, nb_k_heads, seq_len, valid_elems_per_head] + v_cache_f32 = v_cache[:, :, :seq_len, :valid_elems_per_v_head].to( + torch.float32 + ) # [batch_size, nb_k_heads, seq_len, valid_elems_per_v_head] # Calculate sliding window start position if sliding_win_size == 0 or seq_len < sliding_win_size: @@ -81,49 +82,38 @@ def ref_attention( else: seq_beg = seq_len - sliding_win_size - # Set positions before sliding window to negative infinity (masking) + # QยทK^T: [batch_size, nb_k_heads, head_grp_size, seq_len] + gemm0_acc = torch.matmul(q_f32, k_cache_f32.transpose(-2, -1)) * qk_scale + + # Apply sliding window mask if seq_beg > 0: - gemm0_acc[:, :seq_beg] = float("-inf") + gemm0_acc[:, :, :, :seq_beg] = float("-inf") - 
# q_f32: [head_grp_size, valid_elems_per_head] - # k_cache_f32[seq_beg:seq_len]: [valid_seq_len, valid_elems_per_head] - if seq_beg < seq_len: - valid_k_cache = k_cache_f32[ - seq_beg:seq_len - ] # [valid_seq_len, valid_elems_per_head] - valid_scores = ( - torch.matmul(q_f32, valid_k_cache.t()) * qk_scale - ) # [head_grp_size, valid_seq_len] - gemm0_acc[:, seq_beg:seq_len] = valid_scores + # Softmax + row_max = torch.max(gemm0_acc, dim=-1, keepdim=True)[ + 0 + ] # [batch_size, nb_k_heads, head_grp_size, 1] + x = torch.exp( + gemm0_acc - row_max + ) # [batch_size, nb_k_heads, head_grp_size, seq_len] - row_max = torch.max(gemm0_acc, dim=1, keepdim=True)[0] # [head_grp_size, 1] - x = torch.exp(gemm0_acc - row_max) # [head_grp_size, seq_len] + row_sum = torch.sum( + x, dim=-1, keepdim=True + ) # [batch_size, nb_k_heads, head_grp_size, 1] - row_sum = torch.sum(x, dim=1, keepdim=True) # [head_grp_size, 1] + # Add attention sinks contribution + if attention_sinks is not None: + # attention_sinks: [nb_k_heads, head_grp_size] + # row_max: [batch_size, nb_k_heads, head_grp_size, 1] + sink_weights = torch.exp( + attention_sinks.unsqueeze(0).unsqueeze(-1) - row_max + ) # [batch_size, nb_k_heads, head_grp_size, 1] + row_sum = row_sum + sink_weights x = x * rcp_x_scale - if seq_beg < seq_len: - valid_x = x[:, seq_beg:seq_len] # [head_grp_size, valid_seq_len] - valid_v_cache = v_cache_f32[ - seq_beg:seq_len - ] # [valid_seq_len, valid_elems_per_v_head] - out = torch.matmul( - valid_x, valid_v_cache - ) # [head_grp_size, valid_elems_per_v_head] - else: - out = torch.zeros( - head_grp_size, - valid_elems_per_v_head, - dtype=torch.float32, - device=q_f32.device, - ) - - if attention_sinks is not None: - sink_weights = torch.exp( - attention_sinks - row_max.squeeze(-1) - ) # [head_grp_size] - row_sum.squeeze(-1)[:] += sink_weights + # Attention ยท V: [batch_size, nb_k_heads, head_grp_size, valid_elems_per_v_head] + out = torch.matmul(x, v_cache_f32) out = out * (x_scale * kv_scale) / row_sum @@ -138,7 +128,22 @@ def ref_attention( @pytest.mark.parametrize("use_sliding_window", [True, False]) @pytest.mark.parametrize("input_type", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_attention_sinks", [True, False]) -@pytest.mark.parametrize("seq_len", [2, 15, 256, 514]) +@pytest.mark.parametrize( + "seq_len", + [ + 2, + 15, + 256, + 512, + pytest.param( + 514, + marks=pytest.mark.xfail( + reason="seq_len=514 is known to fail in full test suite occasionally", + strict=False, + ), + ), + ], +) @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("nb_k_heads", [2, 4]) @pytest.mark.parametrize("tokens_per_page", [16, 64]) @@ -173,7 +178,7 @@ def test_xqa( q_scale, use_fp8_output, ): - set_random_seed(42) + set_random_seed(0) nb_q_heads = nb_k_heads * head_grp_size @@ -268,7 +273,9 @@ def test_xqa( # Shuffle page indices flattened = page_list_arg.flatten() - indices = torch.randperm(flattened.numel(), device="cuda") + generator = torch.Generator(device="cuda") + generator.manual_seed(42) + indices = torch.randperm(flattened.numel(), generator=generator, device="cuda") shuffled_flat = flattened[indices] page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq) @@ -347,8 +354,8 @@ def test_xqa( nb_k_heads, tokens_per_page, sinks=attention_sinks, - q_scale=q_scale, - kv_scale=kv_cache_scale, + q_scale=torch.tensor(q_scale, device="cuda"), + kv_scale=torch.tensor(kv_cache_scale, device="cuda"), sliding_win_size=sliding_win_size, kv_layout=kv_layout, sm_count=sm_count, @@ -356,76 
+363,106 @@ def test_xqa( rcp_out_scale=rcp_out_scale, ) + # Batch reconstruct all K/V caches from paged memory + # [batch_size, nb_k_heads, max_seq_len, valid_elems_per_head] + num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page + batch_k_cache = torch.zeros( + batch_size, + nb_k_heads, + max_seq_len, + valid_elems_per_head, + dtype=input_type, + device="cuda", + ) + batch_v_cache = torch.zeros( + batch_size, + nb_k_heads, + max_seq_len, + valid_elems_per_head, + dtype=input_type, + device="cuda", + ) + for req in range(batch_size): - for b in range(beam_width): - for idx_k_head in range(nb_k_heads): - # Assemble contiguous K/V cache from paged memory using advanced indexing - num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page - pages = page_list_arg[req, :num_pages] # [num_pages] + pages = page_list_arg[req, :num_pages] # [num_pages] + for idx_k_head in range(nb_k_heads): + # Gather all pages at once + if kv_layout == "NHD": + k_pages = cache_k_heads[ + pages, :, idx_k_head, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, :, idx_k_head, :] + else: # HND + k_pages = cache_k_heads[ + pages, idx_k_head, :, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, idx_k_head, :, :] + + # Reshape to contiguous sequence and store + batch_k_cache[req, idx_k_head, : num_pages * tokens_per_page] = ( + k_pages.reshape(-1, valid_elems_per_head) + ) + batch_v_cache[req, idx_k_head, : num_pages * tokens_per_page] = ( + v_pages.reshape(-1, valid_elems_per_head) + ) + + # Reshape q_heads: [batch_size, beam_width, nb_q_heads, dim] -> [batch_size, nb_k_heads, head_grp_size, dim] + # Since beam_width = 1, we can squeeze it + q_reshaped = q_heads.squeeze(1).reshape( + batch_size, nb_k_heads, head_grp_size, valid_elems_per_head + ) - # Gather all pages at once - if kv_layout == "NHD": - # [num_pages, tokens_per_page, nb_k_heads, head_dim] - k_pages = cache_k_heads[ - pages, :, idx_k_head, : - ] # [num_pages, tokens_per_page, head_dim] - v_pages = cache_v_heads[pages, :, idx_k_head, :] - else: # HND - # [num_pages, nb_k_heads, tokens_per_page, head_dim] - k_pages = cache_k_heads[ - pages, idx_k_head, :, : - ] # [num_pages, tokens_per_page, head_dim] - v_pages = cache_v_heads[pages, idx_k_head, :, :] - - # Reshape to contiguous sequence - k_cache = k_pages.reshape( - -1, valid_elems_per_head - ) # [num_pages*tokens_per_page, head_dim] - v_cache = v_pages.reshape(-1, valid_elems_per_head) - - ref_output = ref_attention( - q=q_heads[req][b][ - idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size - ], - k_cache=k_cache, - v_cache=v_cache, - seq_len=seq_len, - q_scale=q_scale, - kv_scale=kv_cache_scale, - x_scale=1.0, - attention_sinks=attention_sinks[idx_k_head, :] - if use_attention_sinks - else None, - sliding_win_size=sliding_win_size if use_sliding_window else 0, - valid_elems_per_head=valid_elems_per_head, - ) - kernel_output = output[req][b][ - idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size - ].to(torch.float32) - if fp8_kv_cache: - atol = 0.05 - rtol = 0.05 - else: - atol = 0.01 - rtol = 0.01 - if use_fp8_output: - ref_output = ref_output * rcp_out_scale - atol = 0.15 - rtol = 0.15 - - diff_abs = torch.abs(ref_output - kernel_output) - diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8) - - within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) - - pass_ratio = within_tolerance.float().mean().item() - - required_ratio = 0.99 - assert pass_ratio >= required_ratio, ( - f"req={req}, b={b}, 
idx_k_head={idx_k_head}: " - f"Total {ref_output.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, " - f"require at least {required_ratio:.1%}" - ) + # Batch compute reference attention + ref_output_batch = ref_attention( + q=q_reshaped, + k_cache=batch_k_cache, + v_cache=batch_v_cache, + seq_len=seq_len, + q_scale=q_scale, + kv_scale=kv_cache_scale, + x_scale=1.0, + attention_sinks=attention_sinks if use_attention_sinks else None, + sliding_win_size=sliding_win_size if use_sliding_window else 0, + valid_elems_per_head=valid_elems_per_head, + ) # [batch_size, nb_k_heads, head_grp_size, valid_elems_per_head] + + # Reshape kernel output to match: [batch_size, beam_width, nb_q_heads, dim] -> [batch_size, nb_k_heads, head_grp_size, dim] + kernel_output_reshaped = ( + output.squeeze(1) + .reshape(batch_size, nb_k_heads, head_grp_size, valid_elems_per_head) + .to(torch.float32) + ) + + if use_fp8_output: + ref_output_batch = ref_output_batch * rcp_out_scale + + # Set tolerances + if fp8_kv_cache: + atol = 0.05 + rtol = 0.05 + else: + atol = 0.01 + rtol = 0.01 + if use_fp8_output: + atol = 0.15 + rtol = 0.15 + + # Compute differences for all elements at once + diff_abs = torch.abs(ref_output_batch - kernel_output_reshaped) + diff_rel = diff_abs / (torch.abs(ref_output_batch) + 1e-8) + within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) + + # One-shot validation for all elements + total_elements = ref_output_batch.numel() + passing_elements = within_tolerance.sum().item() + pass_ratio = passing_elements / total_elements + required_ratio = 0.99 + + assert pass_ratio >= required_ratio, ( + f"Batch validation failed: " + f"Total {total_elements} elements, only {passing_elements} ({pass_ratio:.1%}) meet tolerance criteria, " + f"require at least {required_ratio:.1%}" + ) @pytest.mark.skipif( @@ -446,7 +483,7 @@ def test_xqa_mla( q_scale, enable_pdl, ): - set_random_seed(42) + set_random_seed(0) # MLA specific constants (fixed, not parameterized) nb_k_heads = 1 # MLA only supports 1 K head @@ -571,54 +608,88 @@ def test_xqa_mla( enable_pdl=enable_pdl, ) + # Batch reconstruct all K/V caches from paged memory + # [batch_size, nb_k_heads, max_seq_len, valid_elems_per_head_qk] + num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page + batch_k_cache = torch.zeros( + batch_size, + nb_k_heads, + max_seq_len, + valid_elems_per_head_qk, + dtype=torch.float32, + device="cuda", + ) + batch_v_cache = torch.zeros( + batch_size, + nb_k_heads, + max_seq_len, + valid_elems_per_head_qk, + dtype=torch.float32, + device="cuda", + ) + for req in range(batch_size): - for b in range(beam_width): - for idx_k_head in range(nb_k_heads): - # Assemble contiguous K/V cache from paged memory using advanced indexing - num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page - pages = page_list_arg[req, :num_pages] # [num_pages] + pages = page_list_arg[req, :num_pages] # [num_pages] + for idx_k_head in range(nb_k_heads): + # NHD layout: [num_pages, tokens_per_page, nb_k_heads, head_dim] + k_pages = cache_k_heads[ + pages, :, idx_k_head, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, :, idx_k_head, :] + + # Reshape to contiguous sequence and store + batch_k_cache[req, idx_k_head, : num_pages * tokens_per_page] = ( + k_pages.reshape(-1, valid_elems_per_head_qk) + ) + batch_v_cache[req, idx_k_head, : num_pages * tokens_per_page] = ( + v_pages.reshape(-1, valid_elems_per_head_qk) + ) + + # Reshape q_heads: [batch_size, beam_width, nb_q_heads, dim] -> [batch_size, 
nb_k_heads, head_grp_size, dim] + # Since beam_width = 1, we can squeeze it + q_reshaped = q_heads.squeeze(1).reshape( + batch_size, nb_k_heads, head_grp_size, valid_elems_per_head_qk + ) - # NHD layout: [num_pages, tokens_per_page, nb_k_heads, head_dim] - k_pages = cache_k_heads[ - pages, :, idx_k_head, : - ] # [num_pages, tokens_per_page, head_dim] - v_pages = cache_v_heads[pages, :, idx_k_head, :] + # Batch compute reference attention + ref_output_batch = ref_attention( + q=q_reshaped, + k_cache=batch_k_cache, + v_cache=batch_v_cache, + seq_len=seq_len, + q_scale=q_scale * math.sqrt(576), + kv_scale=kv_cache_scale, + x_scale=1.0, + attention_sinks=None, + sliding_win_size=0, + valid_elems_per_head=valid_elems_per_head_qk, # Q/K dimension (576) + valid_elems_per_v_head=valid_elems_per_head_v, # V dimension (512) + ) # [batch_size, nb_k_heads, head_grp_size, valid_elems_per_v_head] + + # Reshape kernel output to match: [batch_size, beam_width, nb_q_heads, valid_elems_per_v_head] -> [batch_size, nb_k_heads, head_grp_size, valid_elems_per_v_head] + kernel_output_reshaped = ( + output.squeeze(1) + .reshape(batch_size, nb_k_heads, head_grp_size, valid_elems_per_head_v) + .to(torch.float32) + ) - # Reshape to contiguous sequence - k_cache = k_pages.reshape(-1, valid_elems_per_head_qk) - v_cache = v_pages.reshape(-1, valid_elems_per_head_qk) - - ref_output = ref_attention( - q=q_heads[req][b][ - idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size - ], - k_cache=k_cache, - v_cache=v_cache, - seq_len=seq_len, - q_scale=q_scale * math.sqrt(576), - kv_scale=kv_cache_scale, - x_scale=1.0, - attention_sinks=None, - sliding_win_size=0, - valid_elems_per_head=valid_elems_per_head_qk, # Q/K dimension (576) - valid_elems_per_v_head=valid_elems_per_head_v, # V dimension (512) - ).to(torch.float32) - kernel_output = output[req][b][ - idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size - ].to(torch.float32) - atol = 0.05 - rtol = 0.05 - - diff_abs = torch.abs(ref_output - kernel_output) - diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8) - - within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) - - pass_ratio = within_tolerance.float().mean().item() - - required_ratio = 0.95 - assert pass_ratio >= required_ratio, ( - f"req={req}, b={b}, idx_k_head={idx_k_head}: " - f"Total {ref_output.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, " - f"require at least {required_ratio:.1%}" - ) + # Set tolerances + atol = 0.05 + rtol = 0.05 + + # Compute differences for all elements at once + diff_abs = torch.abs(ref_output_batch - kernel_output_reshaped) + diff_rel = diff_abs / (torch.abs(ref_output_batch) + 1e-8) + within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) + + # One-shot validation for all elements + total_elements = ref_output_batch.numel() + passing_elements = within_tolerance.sum().item() + pass_ratio = passing_elements / total_elements + required_ratio = 0.95 + + assert pass_ratio >= required_ratio, ( + f"Batch validation failed: " + f"Total {total_elements} elements, only {passing_elements} ({pass_ratio:.1%}) meet tolerance criteria, " + f"require at least {required_ratio:.1%}" + ) From fd5273ce37334f8b085e15defe5a6fdb63a750d7 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Wed, 26 Nov 2025 03:05:16 +0800 Subject: [PATCH 088/130] fix: some bugs of headDim 256 trtllm-gen fmha kernels. 
(#2137) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This MR updates the trtllm-gen cubins which fix several bugs of headDim 256 fmha kernels. ## ๐Ÿ” Related Issues https://github.com/flashinfer-ai/flashinfer/issues/1993 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Chores** * Updated artifact references and checksums for TRT-LLM FMHA components. * **Tests** * Parameterized attention tests to run with head dimensions 128 and 256; removed the expected failure for the 256-bit decode path so it now runs normally. * Modified a communication test to skip when requested world size exceeds available GPUs instead of erroring. โœ๏ธ Tip: You can customize this high-level summary in your review settings. --------- Co-authored-by: Zihao Ye Co-authored-by: yzh119 --- flashinfer/artifacts.py | 4 ++-- tests/attention/test_trtllm_gen_attention.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index cfb2862e47..b520023b70 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -87,7 +87,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. 
""" - TRTLLM_GEN_FMHA: str = "1e49deb33ec20018ae0acf1d956a579578069da1/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "9f1b6ddaa1592a8339a82fcab7d27a57eff445fd/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988" ) @@ -107,7 +107,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "66757498f573430583d63b04c02bf9e38306eefe2ce31df9b5d923d99bd15d84" + "a5a60600a80076317703695f56bbef2f0a44075ef4e24d7b06ba67ff68bc9da2" ) TRTLLM_GEN_BMM: str = ( "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf" diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 4e2b7aefe0..dd0002ff06 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -414,13 +414,13 @@ def _test_trtllm_batch_prefill( max_q_len, max_kv_len, device_scale, + head_dim, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] != 10: pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") # Set up test parameters torch.manual_seed(0) - head_dim = 128 # Generate random sequence lengths num_qo_heads = num_kv_heads * head_grp_size @@ -639,6 +639,7 @@ def _test_trtllm_batch_prefill( @pytest.mark.parametrize("enable_sink", [True, False]) @pytest.mark.parametrize("max_q_len", [511]) @pytest.mark.parametrize("max_kv_len", [2047]) +@pytest.mark.parametrize("head_dim", [128, 256]) def test_trtllm_batch_prefill( kv_layout, batch_size, @@ -653,6 +654,7 @@ def test_trtllm_batch_prefill( enable_sink, max_q_len, max_kv_len, + head_dim, ): _test_trtllm_batch_prefill( kv_layout, @@ -669,6 +671,7 @@ def test_trtllm_batch_prefill( max_q_len, max_kv_len, kv_dtype == "fp8", + head_dim, ) @@ -690,6 +693,7 @@ def test_trtllm_batch_prefill( @pytest.mark.parametrize("enable_sink", [False]) @pytest.mark.parametrize("max_q_len", [8192]) @pytest.mark.parametrize("max_kv_len", [8192]) +@pytest.mark.parametrize("head_dim", [128, 256]) def test_trtllm_batch_prefill_bs1( kv_layout, batch_size, @@ -704,6 +708,7 @@ def test_trtllm_batch_prefill_bs1( enable_sink, max_q_len, max_kv_len, + head_dim, ): _test_trtllm_batch_prefill( kv_layout, @@ -720,6 +725,7 @@ def test_trtllm_batch_prefill_bs1( max_q_len, max_kv_len, False, + head_dim, ) @@ -1202,7 +1208,6 @@ def test_trtllm_batch_decode_head_dim_256( device_scale, ): # Small number of test cases for head_dim = 256 - pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256") _test_trtllm_batch_decode( "trtllm-gen", kv_layout, From aeeccac5dfd740999cecb1b0247341469018f7d6 Mon Sep 17 00:00:00 2001 From: YAMY <74099316+YAMY1234@users.noreply.github.com> Date: Tue, 25 Nov 2025 11:05:32 -0800 Subject: [PATCH 089/130] =?UTF-8?q?fix(trtllm):=20reset=20negative=20strid?= =?UTF-8?q?eBatch=20to=200=20for=20ragged=20KV=20layout=20to=20=E2=80=A6?= =?UTF-8?q?=20(#2134)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Fix TMA descriptor failures for ragged KV layouts in the TRT-LLM FMHA path. When using `trtllm_ragged_attention_deepseek` with a non-paged, non-contiguous KV layout (ragged KV), the KV batch dimension is effectively collapsed (shape `[head_dim, sum_seq_lens_kv, num_heads_kv, 1]`). However, `kStrideBatch` / `vStrideBatch` can be set to a large `numel()`-based value that overflows a 32-bit `int` and becomes a negative sentinel. 
This negative stride is then interpreted as a `uint64_t` in `buildNdTmaDescriptor`, producing an enormous `strideInBytes` and causing `cuTensorMapEncodeTiled` to fail with: > Error: Failed to initialize the TMA descriptor due to invalid argument This PR updates `makeStrideKv` so that for ragged KV layouts (`!isPagedKv && !isContiguousKv`), any negative `strideBatch` is treated as a sentinel and clamped to `0`. This matches the actual memory layout (no real batch stride for the collapsed batch dimension) and prevents overflow in the TMA descriptor. ## ๐Ÿ” Related Issues - SGLang: DeepSeek-R1 + `trtllm_ragged_attention_deepseek` on SM100 with long KV (e.g., 128k total KV tokens) failing at `buildNdTmaDescriptor` with invalid TMA configuration. [[Bug] [DeepSeek-R1] Error: Failed to initialize the TMA descriptor due to invalid argument on B200](https://github.com/sgl-project/sglang/issues/13775) ## Pull Request Checklist ### Pre-commit Checks - [x] I have installed `pre-commit` (e.g., `pip install pre-commit`). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run `pre-commit run --all-files` and fixed any reported issues. ## Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (unittest, etc.). ## Reviewer Notes - The change is intentionally minimal and only affects ragged KV layouts. - The goal is to keep the existing device kernels intact while fixing the host-side TMA descriptor construction for this layout. ## Summary by CodeRabbit * **Bug Fixes** * Prevented negative batch stride for ragged key/value cache layouts, avoiding invalid descriptor errors in large or irregular attention workloads. * **Tests** * Added an integration test that reproduces large ragged KV scenarios to verify the clamped stride behavior on CUDA systems. * Updated a communication test to skip when required GPUs aren't available instead of failing, improving test robustness. โœ๏ธ Tip: You can customize this high-level summary in your review settings. --------- Co-authored-by: Zihao Ye Co-authored-by: yzh119 --- include/flashinfer/trtllm/fmha/kernelParams.h | 6 + .../attention/test_trtllm_ragged_kv_stride.py | 118 ++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 tests/attention/test_trtllm_ragged_kv_stride.py diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index c184ad9e10..bc45832968 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -331,6 +331,12 @@ struct KernelParams { strideHeads = options.vStrideHeads; strideBatch = options.vStrideBatch; } + + // Ragged layout has no batch stride; reset negative overflow to 0 for TMA descriptor. + if (!isPagedKv(options.mQkvLayout) && !isContiguousKv(options.mQkvLayout) && strideBatch < 0) { + strideBatch = 0; + } + // The 3 strides (the other ones are 1 and 0). return std::make_tuple(strideKeysVals, strideHeads, strideBatch); } diff --git a/tests/attention/test_trtllm_ragged_kv_stride.py b/tests/attention/test_trtllm_ragged_kv_stride.py new file mode 100644 index 0000000000..69c11359a7 --- /dev/null +++ b/tests/attention/test_trtllm_ragged_kv_stride.py @@ -0,0 +1,118 @@ +import pytest +import torch + +import flashinfer +from flashinfer.utils import get_compute_capability + + +@pytest.mark.cuda +def test_trtllm_ragged_kv_large_stride_overflow(): + """ + Test that ragged KV with large numel (>2^31) doesn't cause TMA descriptor error. 
+ + Constructs a scenario where key.numel() = 131072 * 128 * 192 > 2^31, which + triggers int32 overflow in kStrideBatch. Before the fix, this caused negative + stride and TMA descriptor error. After the fix, negative strideBatch is clamped + to 0 for ragged layouts. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + if not hasattr(flashinfer.prefill, "trtllm_ragged_attention_deepseek"): + pytest.skip("trtllm_ragged_attention_deepseek is not available in this build") + + device = torch.device("cuda") + compute_capability = get_compute_capability(device) + if compute_capability[0] != 10: + pytest.skip( + f"TRTLLM-gen ragged attention requires SM100 and SM103 GPUs, got sm{compute_capability[0]}{compute_capability[1]}" + ) + + torch.manual_seed(42) + + # Configuration that triggers numel > 2^31 + batch_size = 16 + max_kv_len = 8192 + num_kv_heads = 128 + head_dim_qk = 192 + head_dim_vo = 128 + + # Construct ragged Q + seq_lens_q = torch.randint( + low=50, high=150, size=(batch_size,), device=device, dtype=torch.int32 + ) + cum_seq_lens_q = torch.cat( + [ + torch.zeros(1, device=device, dtype=torch.int32), + torch.cumsum(seq_lens_q, dim=0, dtype=torch.int32), + ], + dim=0, + ) + total_q = int(cum_seq_lens_q[-1].item()) + max_q_len = int(seq_lens_q.max().item()) + + q = torch.randn( + total_q, + num_kv_heads, + head_dim_qk, + device=device, + dtype=torch.bfloat16, + ) + + # Construct ragged KV: total_kv = 16 * 8192 = 131072 + # key.numel() = 131072 * 128 * 192 = 3,221,225,472 (0xC0000000) > 2^31 + seq_lens_kv = torch.full( + (batch_size,), max_kv_len, device=device, dtype=torch.int32 + ) + cum_seq_lens_kv = torch.arange( + 0, + (batch_size + 1) * max_kv_len, + max_kv_len, + device=device, + dtype=torch.int32, + ) + total_kv = int(cum_seq_lens_kv[-1].item()) + + k = torch.randn( + total_kv, + num_kv_heads, + head_dim_qk, + device=device, + dtype=torch.bfloat16, + ) + v = torch.randn( + total_kv, + num_kv_heads, + head_dim_vo, + device=device, + dtype=torch.bfloat16, + ) + + workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=device) + scale = float(1.0 / (head_dim_qk**0.5)) + + # Should not raise "buildNdTmaDescriptor: invalid argument" error + output = flashinfer.prefill.trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=workspace_buffer, + seq_lens=seq_lens_kv, + max_q_len=max_q_len, + max_kv_len=max_kv_len, + bmm1_scale=scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=batch_size, + window_left=-1, + cum_seq_lens_q=cum_seq_lens_q, + cum_seq_lens_kv=cum_seq_lens_kv, + enable_pdl=False, + is_causal=True, + return_lse=False, + ) + + # Basic shape check + assert output.shape[0] == total_q + assert output.shape[1] == num_kv_heads + assert output.shape[2] == head_dim_vo From 1940b28e2e9de25f488614d39e07e20e5d4138de Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Wed, 26 Nov 2025 03:07:11 +0800 Subject: [PATCH 090/130] feat: add trtllm-gen per-tensor sparseMla kernels. (#2138) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This MR adds trtllm-gen per-tensor sparseMla kernels. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). 
- [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Added Sparse MLA mode to enable top-k sparse attention paths and configure sparse top-k behavior. * **Performance** * Improved kernel selection and runtime behavior to better support sparse MLA and varied head dimensions. * **Tests** * Expanded tests for multiple head dimensions and added comprehensive sparse MLA decoding tests and utilities. * **Validation** * Strengthened input/shape/runtime checks for sparse MLA configuration. * **Chores** * Updated public artifact references/checksums; tests now skip when insufficient GPUs are available. โœ๏ธ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Co-authored-by: Zihao Ye Co-authored-by: yzh119 --- csrc/fmhaReduction.cu | 23 +- csrc/trtllm_fmha_kernel_launcher.cu | 34 +- flashinfer/decode.py | 35 +- .../flashinfer/trtllm/fmha/fmhaKernels.cuh | 34 +- .../flashinfer/trtllm/fmha/fmhaRunnerParams.h | 4 + include/flashinfer/trtllm/fmha/kernelParams.h | 19 +- tests/attention/test_trtllm_gen_mla.py | 455 ++++++++++++++++++ 7 files changed, 566 insertions(+), 38 deletions(-) diff --git a/csrc/fmhaReduction.cu b/csrc/fmhaReduction.cu index 1f1ca8c755..e329e1c14b 100644 --- a/csrc/fmhaReduction.cu +++ b/csrc/fmhaReduction.cu @@ -34,7 +34,7 @@ namespace kernels { template __global__ void __launch_bounds__(NumThreadsPerCta, 2) - fmhaReductionKernel(KernelParams const params, int32_t numCtasForReduction, + fmhaReductionKernel(KernelParams const params, bool sparseMla, int32_t numCtasForReduction, int32_t numCtasForAllHeads, int32_t numHeadDimCtasV) { // clang-format off // The shape of partialO buffer: [batchSize, numHeadCtas, numCtasQ, numCtasKv, TileSizePerCtaQ, headDimPerCta]. @@ -64,10 +64,25 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2) // The number of validRows. int32_t const numValidRows{TileSizePerCtaQ}; + // The seqOffsetQ. + int32_t const seqOffsetQ{params.ptrCumSeqLensQ == nullptr ? batchIdx * params.mMaxSeqLenQ + : params.ptrCumSeqLensQ[batchIdx]}; + // The seqLenQ. + int32_t const seqLenQ{params.ptrCumSeqLensQ == nullptr + ? params.mMaxSeqLenQ + : (params.ptrCumSeqLensQ[batchIdx + 1] - seqOffsetQ)}; + // Early exit if ctaIdxQ >= seqLenQ, where each CTA processes one tokenQ. + if (ctaIdxQ >= seqLenQ) { + return; + } // The actual number of seqLenKv. int32_t seqLenKv{params.ptrSeqLensKv[batchIdx]}; // Consider the causal-mask speculative decoding. seqLenKv = seqLenKv - ((params.mMaxSeqLenQ - 1) - ctaIdxQ); + // Consider sparseMlaTopK. + if (sparseMla) { + seqLenKv = min(seqLenKv, params.mSparseMlaTopK); + } // The actual number of CtasKv (TileSizeKv is always 128 for now). int32_t numCtasKv{min((seqLenKv + 127) / 128, params.mMaxNumCtasKv)}; @@ -336,7 +351,7 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams config.numAttrs = 1; // Select the kernel function pointer. 
- void (*kernel)(KernelParams const, int32_t, int32_t, int32_t) = nullptr; + void (*kernel)(KernelParams const, bool, int32_t, int32_t, int32_t) = nullptr; if (headDimPerCtaV == 128) { SELECT_FMHA_REDUCTION_KERNEL(128); } else if (headDimPerCtaV == 256) { @@ -346,8 +361,8 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams } // Launch the kernel. - cudaLaunchKernelEx(&config, kernel, params, numCtasForReduction, numCtasForAllHeads, - numHeadDimCtasV); + cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseMla, numCtasForReduction, + numCtasForAllHeads, numHeadDimCtasV); cudaError_t err = cudaGetLastError(); FLASHINFER_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err)); } diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 5c1de17bb0..89fe53b874 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -82,8 +82,8 @@ void trtllm_paged_attention_launcher( int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl, - int64_t workspace_size, cudaStream_t stream) { + int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t sm_count, + bool enable_pdl, int64_t workspace_size, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -139,6 +139,12 @@ void trtllm_paged_attention_launcher( runner_params.ptrAttentionSinks = attention_sinks; runner_params.enable_pdl = enable_pdl; + // The sparse MLA parameters. 
+ runner_params.mSparseMla = sparse_mla_top_k > 0; + runner_params.mSparseMlaTopK = sparse_mla_top_k; + TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0) + << "Only decode MLA supports sparse MLA"; + AlignedAllocator float_allocator(workspace_buffer, workspace_size); if (mode == TllmPagedAttentionMode::Context) { runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal; @@ -201,15 +207,13 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) { inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; } -void trtllm_paged_attention_decode(TensorView out, Optional out_scale_factor, - TensorView query, TensorView key_cache, TensorView value_cache, - TensorView workspace_buffer, TensorView block_tables, - TensorView seq_lens, int64_t max_kv_len, - Variant bmm1_scale, - Variant bmm2_scale, double o_sf_scale, - int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t window_left, int64_t sm_count, bool enable_pdl, - int64_t workspace_size, Optional attention_sinks) { +void trtllm_paged_attention_decode( + TensorView out, Optional out_scale_factor, TensorView query, TensorView key_cache, + TensorView value_cache, TensorView workspace_buffer, TensorView block_tables, + TensorView seq_lens, int64_t max_kv_len, Variant bmm1_scale, + Variant bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, + int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, + bool enable_pdl, int64_t workspace_size, Optional attention_sinks) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); @@ -287,8 +291,8 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, - o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, - stream); + o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count, + enable_pdl, workspace_size, stream); } void trtllm_paged_attention_context( @@ -367,8 +371,8 @@ void trtllm_paged_attention_context( max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, - bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, - enable_pdl, workspace_size, stream); + bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, + /*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream); } void trtllm_ragged_attention_launcher( diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 1f682c9844..3f9f03ebb7 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1922,6 +1922,7 @@ def _paged_run( -1, # o_sf_vec_size 0, # o_sf_start_index window_left, + 0, # sparse_mla_top_k self._sm_count, enable_pdl, workspace_size, @@ -2328,6 +2329,7 @@ def trtllm_batch_decode_with_kv_cache( o_sf_vec_size or -1, o_sf_start_index, window_left, + 0, # sparse_mla_top_k sm_count, enable_pdl, workspace_buffer.numel() * workspace_buffer.element_size(), @@ -2500,6 +2502,7 @@ def 
_check_trtllm_gen_mla_shape( qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, + sparse_mla_top_k, page_table, page_size, ): @@ -2524,16 +2527,23 @@ def _check_trtllm_gen_mla_shape( f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}" ) - B_block_table, block_num = page_table.shape - block_size = page_size - if B_q != B_block_table: - raise ValueError( - f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}" - ) - if block_num % (128 / block_size) != 0: - raise ValueError( - f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" - ) + if sparse_mla_top_k > 0: + page_table_shape = page_table.shape + if page_table_shape != (B_q, Q_len, sparse_mla_top_k): + raise ValueError( + f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}" + ) + else: + B_block_table, block_num = page_table.shape + block_size = page_size + if B_q != B_block_table: + raise ValueError( + f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}" + ) + if block_num % (128 / block_size) != 0: + raise ValueError( + f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" + ) @flashinfer_api @@ -2547,6 +2557,7 @@ def trtllm_batch_decode_with_kv_cache_mla( block_tables: torch.Tensor, seq_lens: torch.Tensor, max_seq_len: int, + sparse_mla_top_k: int = 0, out: Optional[torch.Tensor] = None, bmm1_scale: Union[float, torch.Tensor] = 1.0, bmm2_scale: Union[float, torch.Tensor] = 1.0, @@ -2562,6 +2573,7 @@ def trtllm_batch_decode_with_kv_cache_mla( qk_nope_head_dim: qk_nope_head_dim, must be 128 kv_lora_rank: kv_lora_rank, must be 512 qk_rope_head_dim: qk_rope_head_dim, must be 64 + sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA. block_tables: page_table of kv cache, [batch_size, num_pages] seq_lens: query_len max_seq_len: max sequence length for kv_cache @@ -2654,6 +2666,7 @@ def trtllm_batch_decode_with_kv_cache_mla( qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, + sparse_mla_top_k, block_tables, block_size, ) @@ -2687,6 +2700,7 @@ def trtllm_batch_decode_with_kv_cache_mla( -1, # o_sf_vec_size 0, # o_sf_start_index -1, # window_left + sparse_mla_top_k, sm_count, enable_pdl, workspace_buffer.numel() * workspace_buffer.element_size(), @@ -2768,6 +2782,7 @@ def xqa_batch_decode_with_kv_cache_mla( qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, + 0, # sparse_mla_top_k block_tables, block_size, ) diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index 5bd91f4064..7fb695ed6d 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -333,6 +333,10 @@ class TllmGenFmhaKernel { if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { // The maximum attention window (the maximum number of tokensKv that will be attended to). int maxAttentionWindow{params.mMaxSeqLenKv}; + // The sparseMla only selects topK tokensKv. + if (params.mSparseMla) { + maxAttentionWindow = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK); + } // Some of the tilesKv will be skipped if the sliding window attention or chunked attention is // used. if (isSlidingOrChunkedCausalMask(selectKernelParams.mMaskType)) { @@ -365,7 +369,8 @@ class TllmGenFmhaKernel { // Need to select a different kernel. 
selectKernelParams.mSelectNewKernel = true; } else if (totalNumCtas < params.mMultiProcessorCount && isMlaGenKernel(params) && - selectKernelParams.mTileSizeKv == 128 && getEnvUseTileSizeKv64ForTrtllmGen()) { + !params.mSparseMla && selectKernelParams.mTileSizeKv == 128 && + getEnvUseTileSizeKv64ForTrtllmGen()) { // Use smaller tileSizeKv to fully utilize the SMs. selectKernelParams.mTileSizeKv = 64; // Need to select a different kernel. @@ -461,13 +466,15 @@ class TllmGenFmhaKernel { // We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the // following conditions are met: // 1. The number of headsQPerKv is <= 32. - // 2. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned + // 2. The number of headsQPerKv is < 128 for sparseMla. + // 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned // later) and // the numCtas (after splitting the heads across multiple CTAs) <= // params.mMultiProcessorCount. // Check the conditions. - if (params.mNumHeadsQPerKv <= 32 || useSwapsMmaAbMlaGenKernel(params)) { + if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) || + useSwapsMmaAbMlaGenKernel(params)) { kernelType = FmhaKernelType::SwapsMmaAbForGeneration; } else { // Otherwise, we use the high-throughput kernel. @@ -476,6 +483,10 @@ class TllmGenFmhaKernel { if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; } + // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128. + FLASHINFER_CHECK( + !params.mSparseMla || params.mNumHeadsQPerKv == 128, + "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128"); // The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128. if (params.mNumHeadsQPerKv == 128) { selectKernelParams.mUses2CtaMma = true; @@ -524,8 +535,16 @@ class TllmGenFmhaKernel { "Sliding window attention and chunked attention should not be used together"); selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal; } - // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels. - int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage; + + // The number of tokens per page. + int numTokensPerPage = params.mNumTokensPerPage; + // SparseMla kernels use a fixed numTokensPerPage = 1. + if (params.mSparseMla) { + numTokensPerPage = 1; + } else if (!isPagedKv(params.mQkvLayout)) { + // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels. + numTokensPerPage = 0; + } // Debug info. 
std::string info = @@ -542,7 +561,8 @@ class TllmGenFmhaKernel { ", numTokensPerPage=" + std::to_string(numTokensPerPage) + ", maxNumHeadsQPerKvInCta=" + std::to_string(maxNumHeadsQPerKvInCta) + ", reuseSmemKForV=" + std::to_string(selectKernelParams.mReuseSmemKForV) + - ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma); + ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma) + + ", sparseMla=" + std::to_string(params.mSparseMla); IKL_LOG_DEBUG( "Searching for kernel traits (%d available) in TllmGenFmhaKernel(%s, %s, %s, %d) %s", getNumLoadedKernels(), toStr(mDtypeQ), toStr(mDtypeKv), toStr(mDtypeOut), mSM, @@ -555,7 +575,7 @@ class TllmGenFmhaKernel { selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV, selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta, selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma, - /* sparseMla */ false), + params.mSparseMla), info); } diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index b05ce51ae3..ab48bc04cd 100755 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -287,6 +287,10 @@ struct TllmGenFmhaRunnerParams { float mScaleSfKv; // The SF scale for output. float mScaleSfO; + // Whether to use sparse MLA. + bool mSparseMla; + // The top k value for sparse MLA. + int mSparseMlaTopK; // The cuda stream. cudaStream_t stream; // Whether to enable PDL (Programmatic Dependent Launch). diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index bc45832968..6e62c05543 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -492,8 +492,8 @@ struct KernelParams { // Check shape must be in range [1, 2^32] int32_t dim = shapes.size(); - // Max five dimension and min 3 dimension. - FLASHINFER_CHECK((dim <= 5) && (dim >= 3)); + // Max five dimension and min 2 dimension. + FLASHINFER_CHECK((dim <= 5) && (dim >= 2)); // Check shape range. for (int32_t ii = 0; ii < dim; ++ii) { FLASHINFER_CHECK(shapes[ii] >= (uint64_t(1))); // Size must be min 1 @@ -603,6 +603,16 @@ struct KernelParams { std::vector tileShapeKv(shapeK.size(), 1); tileShapeKv[0] = numEltsInClampedHeadDimKv / numEltsDivisor; tileShapeKv[1] = numKeysPerTile; + + // If sparse MLA is enabled, the shape and stride for K need to be updated for 2D layout + // (numTokensKvInPagedKv, headDimQk). + if (options.mSparseMla) { + shapeK = std::vector{static_cast(options.mHeadDimQk), + static_cast(INT_MAX)}; + strideK = std::vector{1, static_cast(options.mHeadDimQk)}; + tileShapeKv[1] = 1; + } + // Build tma descriptor for K. params.tmaK_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeK, strideK, tileShapeKv, const_cast(kPtr), @@ -726,6 +736,11 @@ struct KernelParams { params.mStartTokenIdxSfO = options.mSfStartTokenIdx; params.mScaleSfKv = options.mScaleSfKv; params.ptrSoftmaxStats = options.softmaxStatsPtr; + // The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the + // indices. + FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0, + "SparseMlaTopK must be a multiple of 4"); + params.mSparseMlaTopK = options.mSparseMlaTopK; // TODO: Integrate trtllm block-sparse attention kernels when needed. 
params.mUseBlockSparseAttention = false; return params; diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index d56be03eb6..d71e8cb386 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -1,5 +1,6 @@ import pytest import torch +import random import flashinfer from flashinfer.utils import get_compute_capability @@ -9,6 +10,205 @@ workspace_size = 128 * 1024 * 1024 +def generate_sparse_indices( + batch_size: int, + q_len_per_request: int, + seq_lens: torch.Tensor, + topk: int, + page_size: int, + block_tables: torch.Tensor, + device: str, + seed: int = 42, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generate sparse attention indices for MLA. + + Returns: + abs_indices: [batch_size, q_len_per_request, topk] - absolute positions in sequence + indices_in_kvcache: [batch_size, q_len_per_request, topk] - positions in blocked KV cache + """ + random.seed(seed) + torch.manual_seed(seed) + + block_tables_cpu = block_tables.cpu() + seq_lens_cpu = seq_lens.cpu() + + abs_indices = torch.empty( + batch_size, q_len_per_request, topk, dtype=torch.int32, device="cpu" + ) + indices_in_kvcache = torch.empty( + batch_size, q_len_per_request, topk, dtype=torch.int32, device="cpu" + ) + + for i in range(batch_size): + cur_seq_len = int(seq_lens_cpu[i].item()) + # Generate indices for each query position + for j in range(q_len_per_request): + # Randomly sample topk positions from the sequence + if cur_seq_len > 0: + # cur_abs_indices = torch.randperm(cur_seq_len, device="cpu")[:topk] + cur_abs_indices = torch.arange(0, topk, device="cpu") + # Convert to blocked indices + cur_blocked_indices = block_tables_cpu[ + i, cur_abs_indices // page_size + ] * page_size + (cur_abs_indices % page_size) + else: + cur_abs_indices = torch.empty(0, dtype=torch.int32, device="cpu") + cur_blocked_indices = torch.empty(0, dtype=torch.int32, device="cpu") + + # Pad with -1 if we don't have enough indices + if len(cur_abs_indices) < topk: + pad_len = topk - len(cur_abs_indices) + cur_abs_indices = torch.cat( + [ + cur_abs_indices, + torch.full((pad_len,), -1, device="cpu", dtype=torch.int32), + ] + ) + cur_blocked_indices = torch.cat( + [ + cur_blocked_indices, + torch.full((pad_len,), -1, device="cpu", dtype=torch.int32), + ] + ) + + # Randomly permute the indices + # perm = torch.randperm(topk, device="cpu") + perm = torch.arange(0, topk, device="cpu") + cur_abs_indices = cur_abs_indices[perm] + cur_blocked_indices = cur_blocked_indices[perm] + + abs_indices[i, j, :] = cur_abs_indices + indices_in_kvcache[i, j, :] = cur_blocked_indices + + return abs_indices.to(device), indices_in_kvcache.to(device) + + +def sparse_mla_reference_torch( + cache_seqlens: torch.Tensor, # [batch_size] + block_table: torch.Tensor, # [batch_size, ?] + q: torch.Tensor, # [batch_size, s_q, h_q, d] + blocked_k: torch.Tensor, # [?, block_size, d] + blocked_v: torch.Tensor, # [?, block_size, dv] + page_size: int, + is_causal: bool, + sm_scale: float, + indices: torch.Tensor | None = None, # [batch_size, s_q, topk] +) -> tuple[torch.Tensor, torch.Tensor]: + """ + A reference implementation in PyTorch for MLA attention. + Based on FlashMLA's reference implementation. 
+ + Args: + cache_seqlens: Sequence lengths for each batch [batch_size] + block_table: Block table mapping [batch_size, max_num_blocks] + q: Query tensor [batch_size, s_q, h_q, d] + blocked_k: Blocked key cache [num_blocks, block_size, d] + blocked_v: Blocked value cache [num_blocks, block_size, dv] + page_size: Size of each block/page + is_causal: Whether to apply causal masking + sm_scale: Softmax scale factor + indices: Optional sparse indices [batch_size, s_q, topk] + + Returns: + output: Attention output [batch_size, s_q, h_q, dv] + lse: Log-sum-exp values [batch_size, h_q, s_q] + """ + + def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): + """Create attention mask for top-k sparse attention.""" + mask = torch.zeros(s_q, s_k, dtype=torch.bool) + for i in range(s_q): + cur_indices = indices[i] + valid_indices = cur_indices[cur_indices != -1] + mask[i, valid_indices] = True + return mask + + def scaled_dot_product_attention( + batch_idx: int, + query: torch.Tensor, # [h_q, s_q, d] + key: torch.Tensor, # [s_k, d] + value: torch.Tensor, # [s_k, dv] + is_causal: bool, + sm_scale: float, + indices: torch.Tensor | None, # [s_q, topk] + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot-product attention.""" + h_q = query.size(0) + s_q = query.shape[-2] + s_k = key.shape[-2] + dv = value.shape[-1] + + query = query.float() + key = key.float() + value = value.float() + + # Handle NaN values in KV + key[key != key] = 0.0 + value[value != value] = 0.0 + + # Compute attention weights: [h_q, s_q, s_k] + attn_weight = query @ key.transpose(-2, -1) + + # Apply masking if needed + if (is_causal and query.size(1) > 1) or indices is not None: + mask = torch.ones(s_q, s_k, dtype=torch.bool) + if is_causal: + mask = mask.tril(diagonal=s_k - s_q) + if indices is not None: + mask &= get_topk_attn_mask(s_q, s_k, indices) + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device=query.device) + mask = mask.to(device=query.device) + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + attn_weight += attn_bias.to(query.dtype) + + # Scale and softmax + attn_weight *= sm_scale + lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + + # Compute output + output = attn_weight @ value # [h_q, s_q, dv] + + # Correct for query tokens which have no attendable keys + lonely_q_mask = lse == float("-inf") + output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 + lse[lonely_q_mask] = float("+inf") + + return output, lse + + b, s_q, h_q, d = q.size() + dv = blocked_v.size(2) + cache_seqlens_cpu = cache_seqlens.cpu() + + out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32) + + for i in range(b): + cur_len = int(cache_seqlens_cpu[i].item()) + cur_num_blocks = (cur_len + page_size - 1) // page_size + cur_block_indices = block_table[i][0:cur_num_blocks] + + # Gather KV for this sequence + cur_key = blocked_k[cur_block_indices].view(-1, d)[:cur_len, ...] + cur_value = blocked_v[cur_block_indices].view(-1, dv)[:cur_len, ...] 
+ + cur_out, cur_lse = scaled_dot_product_attention( + i, + q[i].transpose(0, 1), # [h_q, s_q, d] + cur_key, # [s_k, d] + cur_value, # [s_k, dv] + is_causal, + sm_scale, + indices[i] if indices is not None else None, + ) + out_ref[i] = cur_out.transpose(0, 1) + lse_ref[i] = cur_lse + + out_ref = out_ref.to(torch.bfloat16).to(q.device) + return out_ref, lse_ref + + def trtllm_batch_decode_mla( batch_size: int, scale: float, @@ -296,3 +496,258 @@ def test_dsr1_trtllm_mla( backend, MAX_SEQ_LEN, ) + + +@pytest.mark.parametrize( + "batch_size", + [1, 2, 4, 16, 32, 64, 128], +) +@pytest.mark.parametrize("scale", [1.0]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("q_len_per_request", [1, 2]) +@pytest.mark.parametrize("topk", [128, 2048]) +@pytest.mark.parametrize("is_varlen", [False, True]) +@pytest.mark.parametrize("enable_pdl", [True, False, None]) +@pytest.mark.parametrize("backend", ["trtllm-gen"]) +def test_trtllm_batch_decode_mla_sparse( + batch_size: int, + scale: float, + dtype: torch.dtype, + q_len_per_request: int, + topk: int, + is_varlen: bool, + enable_pdl: bool, + backend: str, +): + """ + Test sparse MLA decoding with top-k attention. + Based on FlashMLA test patterns from: + https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_decoding.py + """ + compute_capability = get_compute_capability(torch.device(device="cuda")) + if backend == "trtllm-gen": + if compute_capability[0] != 10: + pytest.skip("TRTLLM-GEN MLA only supports SM100 and SM103 GPUs") + + torch.manual_seed(42) + device = "cuda:0" + + # Deepseek attention config (decode-MLA) + num_q_heads = 128 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + kv_lora_rank = 512 + + # Fixed or variable sequence lengths + if is_varlen: + # Variable sequence lengths + MAX_SEQ_LEN = 4096 + seq_lens = [ + max( + topk, + int( + torch.distributions.Normal(MAX_SEQ_LEN, MAX_SEQ_LEN / 2) + .sample() + .item() + ), + ) + for _ in range(batch_size) + ] + seq_lens[-1] = MAX_SEQ_LEN # Ensure at least one max length + seq_lens = [min(s, MAX_SEQ_LEN) for s in seq_lens] + else: + # Fixed sequence length + MAX_SEQ_LEN = 4096 + seq_lens = [MAX_SEQ_LEN] * batch_size + + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) + + # Initialize query tensors + query = torch.randn( + batch_size, + q_len_per_request, + num_q_heads, + kv_lora_rank + qk_rope_head_dim, + device=device, + ) + query.clamp_(min=-1.0, max=1.0) + query = query.to(dtype) + + # Calculate blocks needed + page_size = 32 + blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size + max_num_blocks_per_seq = blocks_per_seq.max().item() + total_blocks_needed = int(blocks_per_seq.sum().item()) + + # Generate random but unique block IDs + all_block_ids = torch.randperm(total_blocks_needed, device=device) + + # Create block tables + block_tables = torch.zeros( + (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device + ) + block_id = 0 + for i in range(batch_size): + num_blocks_needed = int(blocks_per_seq[i].item()) + block_tables[i, :num_blocks_needed] = all_block_ids[ + block_id : block_id + num_blocks_needed + ] + block_id += num_blocks_needed + + # Create KV cache + num_blocks = total_blocks_needed + kv_cache = torch.randn( + size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), + device=device, + ) + kv_cache.clamp_(min=-1.0, max=1.0) + kv_cache = kv_cache.to(dtype) + + # Generate sparse indices + abs_indices, indices_in_kvcache = 
generate_sparse_indices( + batch_size, + q_len_per_request, + seq_lens_tensor, + topk, + page_size, + block_tables, + device, + ) + + # Mask unused KV cache entries with NaN for correctness checking + kv_cache_ref = kv_cache.clone() + if dtype == torch.float8_e4m3fn: + kv_cache_ref = kv_cache_ref.to(torch.bfloat16) + + # Mark all positions as NaN initially + all_indices = indices_in_kvcache.flatten().tolist() + all_indices = list(set(all_indices)) + if -1 in all_indices: + all_indices.remove(-1) + + # Only used indices should be valid + kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim) + used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu") + used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True + kv_cache_flat[~used_mask] = float("0") + + # Allocate workspace buffers + global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + workspace_size, dtype=torch.int8, device=device + ) + if global_trtllm_gen_fmha_workspace_buffer is None: + global_trtllm_gen_fmha_workspace_buffer = torch.zeros( + workspace_size, dtype=torch.int8, device=device + ) + workspace_buffer = global_trtllm_gen_fmha_workspace_buffer + # workspace_buffer_ref = global_workspace_buffer + + # Run sparse decode-MLA + query_input = query.clone() + output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=query_input, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=indices_in_kvcache, + seq_lens=seq_lens_tensor, + max_seq_len=max_seq_len, + sparse_mla_top_k=topk, + bmm1_scale=scale / ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5), + bmm2_scale=1.0, + enable_pdl=enable_pdl, + backend=backend, + ) + + # Check workspace buffer is zeroed + assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + + # For now, just check that output has correct shape and no NaNs + expected_shape = (batch_size, q_len_per_request, num_q_heads, kv_lora_rank) + assert output.shape == expected_shape, ( + f"Output shape {output.shape} != {expected_shape}" + ) + + # Check for NaNs + if dtype != torch.float8_e4m3fn: + assert not torch.isnan(output).any(), "Output contains NaN values" + + # Generate reference output using PyTorch implementation + query_ref = query.clone() + if dtype == torch.float8_e4m3fn: + query_ref = query_ref.to(torch.bfloat16) + + # Split kv_cache into K and V components + # K uses full dimension (kv_lora_rank + qk_rope_head_dim) + # V uses only kv_lora_rank dimension + blocked_k = kv_cache_ref # [num_blocks, page_size, kv_lora_rank + qk_rope_head_dim] + blocked_v = kv_cache_ref[ + ..., :kv_lora_rank + ] # [num_blocks, page_size, kv_lora_rank] + + sm_scale = scale / ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5) + + out_ref, lse_ref = sparse_mla_reference_torch( + cache_seqlens=seq_lens_tensor, + block_table=block_tables, + q=query_ref, + blocked_k=blocked_k, + blocked_v=blocked_v, + page_size=page_size, + is_causal=True, # Cover cases where number of attendable kv values are less than topk + sm_scale=sm_scale, + indices=abs_indices, + ) + + # Compare outputs + assert not torch.isnan(output).any(), "Kernel output contains NaN values" + assert not torch.isnan(out_ref).any(), "Reference output contains NaN values" + + if dtype == torch.float8_e4m3fn: + # FP8 has lower precision, use more relaxed tolerances + try: 
+ torch.testing.assert_close( + output.float(), + out_ref.float(), + rtol=1e-1, + atol=1e-1, + ) + except AssertionError as e: + # Calculate element-wise differences for debugging + diff = torch.abs(output.float() - out_ref.float()) + max_diff = diff.max().item() + mean_diff = diff.mean().item() + print(f"Max difference: {max_diff}, Mean difference: {mean_diff}") + print(f"Output sample: {output[0, 0, 0, :8]}") + print(f"Reference sample: {out_ref[0, 0, 0, :8]}") + raise e + else: + # BF16 should have better precision + try: + torch.testing.assert_close( + output.float(), + out_ref.float(), + rtol=2e-2, + atol=8e-4, + ) + except AssertionError as e: + # Calculate element-wise differences for debugging + diff = torch.abs(output.float() - out_ref.float()) + max_diff = diff.max().item() + mean_diff = diff.mean().item() + print(f"Max difference: {max_diff}, Mean difference: {mean_diff}") + print(f"Output sample: {output[0, 0, 0, :8]}") + print(f"Output sample: {output[0, 1, 0, :8]}") + print(f"Reference sample: {out_ref[0, 0, 0, :8]}") + print(f"Reference sample: {out_ref[0, 1, 0, :8]}") + raise e + + print( + f"Sparse MLA test passed: batch_size={batch_size}, topk={topk}, " + f"q_len={q_len_per_request}, varlen={is_varlen}, dtype={dtype}" + ) From d0d99d219b536f492e3c9bcdaa00f2766463d351 Mon Sep 17 00:00:00 2001 From: juju812 Date: Wed, 26 Nov 2025 03:17:41 +0800 Subject: [PATCH 091/130] Use global TuningConfig, to fix memory leak caused by AutoTuner LRU cache and dynamic lambda TuningConfig (#2140) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This PR is to fix a memory leak bug caused by AutoTuner LRU cache and dynamic lambda TuningConfig ## ๐Ÿ” Related Issues https://github.com/flashinfer-ai/flashinfer/issues/2139 ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **Performance** * Reduced autotuner overhead by caching runner parameter names to avoid repeated signature inspection during profiling, speeding up tuning runs. * **New Features** * Centralized reusable tuning presets for mixed-precision GEMM (FP8/FP4) with additional tuning presets to improve autotuning and execution efficiency. โœ๏ธ Tip: You can customize this high-level summary in your review settings. 
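
To make the failure mode concrete, here is a minimal, self-contained sketch of the pattern this patch removes. It does not use flashinfer: `TuningSpec`, `_cached_profile`, `leaky_call`, and `fixed_call` are hypothetical stand-ins for the interaction between a `TuningConfig` carrying lambdas and an LRU-cached AutoTuner path; the authoritative change is the diff below, which hoists the configs to module level (e.g. `_FP8_GEMM_SM100_TUNING_CONFIG`).

```python
import functools


class TuningSpec:
    """Hypothetical stand-in for a config object that holds a lambda."""

    def __init__(self, constraint):
        self.constraint = constraint  # e.g. lambda shapes: shapes[0][-2]


@functools.lru_cache(maxsize=None)
def _cached_profile(spec: TuningSpec) -> int:
    # Stand-in for a cache keyed (directly or indirectly) on the config object:
    # every distinct spec object creates, and retains, a new cache entry.
    return id(spec)


def leaky_call():
    # A fresh lambda -> a fresh TuningSpec per call, so no call ever hits the
    # cache again and the cache (plus the captured closures) grows without bound.
    return _cached_profile(TuningSpec(lambda shapes: shapes[0][-2]))


_GLOBAL_SPEC = TuningSpec(lambda shapes: shapes[0][-2])  # built once at import time


def fixed_call():
    # Reusing a single module-level spec keeps the cache at one entry; this is
    # the shape of the fix applied to the FP8/FP4 GEMM tuning configs below.
    return _cached_profile(_GLOBAL_SPEC)
```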
--------- Co-authored-by: He Jun Co-authored-by: yzh119 --- flashinfer/autotuner.py | 11 +++- flashinfer/gemm/gemm_base.py | 120 +++++++++++++++++++++-------------- 2 files changed, 82 insertions(+), 49 deletions(-) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 9f5fb67489..a81c8f2546 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -458,6 +458,13 @@ def choose_one( # Record the total configs to try self.stats.tuned_op_total_configs[custom_op] = len(profiles) + # Pre-compute runner arg names to avoid calling inspect.signature in the loop + runner_arg_names_map = {} + for r in runners: + runner_arg_names_map[r] = { + param.name for param in inspect.signature(r.forward).parameters.values() + } + for p in profiles: tensors = self._prepare_input_tensors(p, inputs) is_cache_hit, runner_id, tactic, _ = self.search_cache( @@ -470,9 +477,7 @@ def choose_one( for r_id, r in enumerate(runners): # TODO: use FakeTensor here. valid_tactics = r.get_valid_tactics(tensors, p) - runner_arg_names = { - p.name for p in inspect.signature(r.forward).parameters.values() - } + runner_arg_names = runner_arg_names_map[r] if "do_preparation" in runner_arg_names and len(valid_tactics) > 0: r(tensors, tactic=-1, do_preparation=True, **kwargs) for tac in valid_tactics: diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 251e2a4682..15b26f02ee 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -356,6 +356,25 @@ def forward( ) +_FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0,), # a_tensor_index + (-2,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + 4, # out_tensor_index + -2, + lambda shapes: shapes[0][-2], + ), + ), +) + + def fp8_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -376,29 +395,12 @@ def fp8_gemm_sm100( runners.append(_cudnn_gemm_fp8_runner()) assert runners, "No suitable runners found" tuner = AutoTuner.get() - a_tensor_index = 0 - out_tensor_index = 4 - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (-2,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2] - ), - ), - ) inputs = [a, b, scale_a, scale_b, out, workspace_buffer] runner, tactic = tuner.choose_one( "fp8_gemm", runners, - tuning_config, + _FP8_GEMM_SM100_TUNING_CONFIG, inputs, ) @@ -2019,6 +2021,58 @@ def _heuristic_func_mm_fp4( return [c for c in candidate_backends if c in suitable_backends] +def _pad_up(x, y): + return ((x + y - 1) // y) * y + + +_MM_FP4_TUNING_CONFIG_8x4 = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0,), # a_tensor_index + (0,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + 2, # a_scale_tensor_index + 0, + lambda shapes: _pad_up(shapes[0][0], 8), + ), + ConstraintSpec( + 6, # out_tensor_index + 0, + lambda shapes: shapes[0][0], + ), + ), +) + + +_MM_FP4_TUNING_CONFIG_128x4 = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0,), # a_tensor_index + (0,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + 2, # a_scale_tensor_index + 0, + lambda shapes: _pad_up(shapes[0][0], 128), + ), + ConstraintSpec( + 6, # out_tensor_index + 0, + lambda shapes: 
shapes[0][0], + ), + ), +) + + @backend_requirement( { "cudnn": _cudnn_gemm_fp4_requirement, @@ -2138,34 +2192,8 @@ def mm_fp4( # Now we have a list of runners for desired & supported backends. tuner = AutoTuner.get() - a_tensor_index = 0 - a_scale_tensor_index = 2 - out_tensor_index = 6 - - def pad_up(x, y): - return ((x + y - 1) // y) * y - - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (0,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - a_scale_tensor_index, - 0, - lambda shapes: pad_up( - shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128 - ), - ), - ConstraintSpec( - out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] - ), - ), + tuning_config = ( + _MM_FP4_TUNING_CONFIG_8x4 if use_8x4_sf_layout else _MM_FP4_TUNING_CONFIG_128x4 ) inputs = [ From 18004a89a77075605f291c6391fb82092b3cb619 Mon Sep 17 00:00:00 2001 From: Sukrit Date: Tue, 25 Nov 2025 19:32:21 -0500 Subject: [PATCH 092/130] feat: add seed offset args to sampler to allow cuda graph support (#2132) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description This PR adds optional seed/offset args to all the sampler functions to prevent calling the `get_seed_and_offset` function. If that function is not called, we can potentially make the sampler forward call as part of CUDAGraph and use that to replay it. We can directly compute the Seed/offset values, before launching the graph in a similar way to as it is being done in the current method and pass them when making the flashinfer call ## ๐Ÿ” Related Issues #978 : top_k_top_p_sampling_from_logits incompatible with torch.compile + CUDAGraph ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * Optional seed and offset parameters added to sampling APIs to enable deterministic RNG control while remaining optional. * **Tests** * New tests verify reproducible sampling when using the same seed/offset and variability when different values are used. โœ๏ธ Tip: You can customize this high-level summary in your review settings. 
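
As a usage sketch (adapted from the reproducibility tests added at the bottom of this patch; the CUDA-graph capture itself is not shown, the point is only that supplying `seed`/`offset` explicitly bypasses `get_seed_and_offset`, which is what makes the call replayable):

```python
import torch
import flashinfer

probs = torch.rand(8, 32000, device="cuda")
probs = probs / probs.sum(dim=-1, keepdim=True)

# Same seed/offset -> identical samples, independent of host-side generator state.
s1 = flashinfer.sampling.sampling_from_probs(probs, seed=12345, offset=0)
s2 = flashinfer.sampling.sampling_from_probs(probs, seed=12345, offset=0)
assert torch.equal(s1, s2)

# Advancing the offset between replays (e.g. once per decode step, computed on
# the host before launching the captured graph) restores fresh randomness.
s3 = flashinfer.sampling.sampling_from_probs(probs, seed=12345, offset=1000)
```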
--- flashinfer/sampling.py | 125 +++++++++++++++++++++++++++++++---- tests/utils/test_sampling.py | 78 ++++++++++++++++++++++ 2 files changed, 189 insertions(+), 14 deletions(-) diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index a7f334a01a..9f80e3c926 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -88,6 +88,8 @@ def sampling_from_logits( indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: device = logits.device # TODO: support more data types in logits to avoid conversion @@ -95,7 +97,8 @@ def sampling_from_logits( logits = logits.float() batch_size = indices.size(0) if indices is not None else logits.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) - seed, offset = get_seed_and_offset(batch_size * logits.size(1), generator) + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size * logits.size(1), generator) module.sampling_from_logits( logits, samples, @@ -124,12 +127,15 @@ def sampling_from_probs( indices: Optional[torch.Tensor], deterministic: bool, generator: Optional[torch.Generator], + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() batch_size = indices.size(0) if indices is not None else probs.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) - seed, offset = get_seed_and_offset(batch_size, generator) + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size, generator) module.sampling_from_probs( probs, samples, @@ -162,6 +168,8 @@ def top_p_sampling_from_probs( top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() @@ -170,7 +178,8 @@ def top_p_sampling_from_probs( ) batch_size = indices.size(0) if indices is not None else probs.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) - seed, offset = get_seed_and_offset(batch_size * 32, generator) + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size * 32, generator) module.top_p_sampling_from_probs( probs, samples, @@ -205,13 +214,16 @@ def top_k_sampling_from_probs( top_k_val: int, deterministic: bool, generator: Optional[torch.Generator], + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() batch_size = indices.size(0) if indices is not None else probs.size(0) maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None samples = torch.empty(batch_size, dtype=torch.int32, device=device) - seed, offset = get_seed_and_offset(batch_size * 32, generator) + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size * 32, generator) module.top_k_sampling_from_probs( probs, samples, @@ -247,6 +259,8 @@ def min_p_sampling_from_probs( min_p_val: float, deterministic: bool, generator: Optional[torch.Generator], + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() @@ -255,7 +269,8 @@ def min_p_sampling_from_probs( ) batch_size = indices.size(0) if indices is not None else probs.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) - seed, offset = get_seed_and_offset(batch_size, generator) + if seed is 
None or offset is None: + seed, offset = get_seed_and_offset(batch_size, generator) module.min_p_sampling_from_probs( probs, samples, @@ -280,6 +295,8 @@ def top_k_top_p_sampling_from_probs( top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: device = probs.device probs = probs.float() @@ -289,7 +306,8 @@ def top_k_top_p_sampling_from_probs( ) batch_size = indices.size(0) if indices is not None else probs.size(0) samples = torch.empty(batch_size, dtype=torch.int32, device=device) - seed, offset = get_seed_and_offset(batch_size * 32, generator) + if seed is None or offset is None: + seed, offset = get_seed_and_offset(batch_size * 32, generator) module.top_k_top_p_sampling_from_probs( probs, samples, @@ -419,6 +437,8 @@ def chain_speculative_sampling( output_emitted_draft_token_num: torch.Tensor, deterministic: bool, generator: Optional[torch.Generator], + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: device = draft_probs.device draft_probs = draft_probs.float() @@ -428,9 +448,10 @@ def chain_speculative_sampling( output_emitted_draft_token_num = output_emitted_draft_token_num.int() b, n = draft_token_ids.shape output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device) - seed, offset = get_seed_and_offset( - draft_probs.size(0) * (draft_probs.size(1) + 1), generator - ) + if seed is None or offset is None: + seed, offset = get_seed_and_offset( + draft_probs.size(0) * (draft_probs.size(1) + 1), generator + ) module.chain_speculative_sampling( draft_probs, draft_token_ids, @@ -571,6 +592,8 @@ def sampling_from_logits( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: r"""Fused GPU kernel for category sampling from logits. It's equivalent to sampling from :attr:`logits` after applying softmax. @@ -593,6 +616,10 @@ def sampling_from_logits( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`logits`, default is ``False``. + seed: Optional[int] + seed value to use for the rng during the sampling operation. + offset: Optional[int] + offset value to use for the rng during the sampling operation. Returns ------- samples: torch.Tensor @@ -620,7 +647,7 @@ def sampling_from_logits( raise ValueError("Input logits contains NaN.") _check_indices_dtype(indices) return get_sampling_module().sampling_from_logits( - logits, indices, deterministic, generator + logits, indices, deterministic, generator, seed, offset ) @@ -630,6 +657,8 @@ def sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: r"""Fused GPU kernel for category sampling from probabilities. @@ -651,6 +680,10 @@ def sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. + seed: Optional[int] + seed value to use for the rng during the sampling operation. + offset: Optional[int] + offset value to use for the rng during the sampling operation. 
Returns ------- @@ -685,7 +718,7 @@ def sampling_from_probs( raise ValueError("Input probs contains NaN.") _check_indices_dtype(indices) return get_sampling_module().sampling_from_probs( - probs, indices, deterministic, generator + probs, indices, deterministic, generator, seed, offset ) @@ -696,6 +729,8 @@ def top_p_sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. @@ -726,6 +761,10 @@ def top_p_sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. + seed: Optional[int] + seed value to use for the rng during the sampling operation. + offset: Optional[int] + offset value to use for the rng during the sampling operation. Returns ------- @@ -769,7 +808,13 @@ def top_p_sampling_from_probs( _check_indices_dtype(indices) _check_tensor_param(top_p, probs) return get_sampling_module().top_p_sampling_from_probs( - probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator + probs, + indices, + *_to_tensor_scalar_tuple(top_p), + deterministic, + generator, + seed, + offset, ) @@ -780,6 +825,8 @@ def top_k_sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: r"""Fused GPU kernel for top-k sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. @@ -810,6 +857,10 @@ def top_k_sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. + seed: Optional[int] + seed value to use for the rng during the sampling operation. + offset: Optional[int] + offset value to use for the rng during the sampling operation. Returns ------- @@ -853,7 +904,13 @@ def top_k_sampling_from_probs( _check_indices_dtype(indices) _check_tensor_param(top_k, probs) return get_sampling_module().top_k_sampling_from_probs( - probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator + probs, + indices, + *_to_tensor_scalar_tuple(top_k), + deterministic, + generator, + seed, + offset, ) @@ -864,6 +921,8 @@ def min_p_sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: r"""Fused GPU kernel for `min_p sampling `_ from probabilities, @@ -895,6 +954,10 @@ def min_p_sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. + seed: Optional[int] + seed value to use for the rng during the sampling operation. + offset: Optional[int] + offset value to use for the rng during the sampling operation. 
Returns ------- @@ -933,7 +996,13 @@ def min_p_sampling_from_probs( _check_indices_dtype(indices) _check_tensor_param(min_p, probs) return get_sampling_module().min_p_sampling_from_probs( - probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator + probs, + indices, + *_to_tensor_scalar_tuple(min_p), + deterministic, + generator, + seed, + offset, ) @@ -946,6 +1015,8 @@ def top_k_top_p_sampling_from_logits( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: r"""Fused GPU kernel for top-k and top-p sampling from pre-softmax logits, @@ -985,6 +1056,10 @@ def top_k_top_p_sampling_from_logits( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. + seed: Optional[int] + seed value to use for the rng during the sampling operation. + offset: Optional[int] + offset value to use for the rng during the sampling operation. Returns ------- @@ -1042,6 +1117,8 @@ def top_k_top_p_sampling_from_logits( deterministic, check_nan=check_nan, generator=generator, + seed=seed, + offset=offset, ) elif filter_apply_order == "joint": probs = torch.softmax(logits, dim=-1) @@ -1055,6 +1132,8 @@ def top_k_top_p_sampling_from_logits( *_to_tensor_scalar_tuple(top_p), deterministic, generator, + seed, + offset, ) else: raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") @@ -1069,6 +1148,8 @@ def top_k_top_p_sampling_from_probs( deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: r"""Fused GPU kernel for top-k and top-p sampling from probabilities, @@ -1108,6 +1189,10 @@ def top_k_top_p_sampling_from_probs( A random number generator for the operation. check_nan: bool Whether to check nan in :attr:`probs`, default is ``False``. + seed: Optional[int] + seed value to use for the rng during the sampling operation. + offset: Optional[int] + offset value to use for the rng during the sampling operation. Returns ------- @@ -1159,6 +1244,8 @@ def top_k_top_p_sampling_from_probs( deterministic, check_nan=check_nan, generator=generator, + seed=seed, + offset=offset, ) elif filter_apply_order == "joint": if check_nan: @@ -1171,6 +1258,8 @@ def top_k_top_p_sampling_from_probs( *_to_tensor_scalar_tuple(top_p), deterministic, generator, + seed, + offset, ) else: raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") @@ -1372,6 +1461,8 @@ def chain_speculative_sampling( maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, + seed: Optional[int] = None, + offset: Optional[int] = None, ) -> torch.Tensor: r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in paper `Accelerating Large Language Model Decoding with Speculative Sampling `_), @@ -1407,6 +1498,10 @@ def chain_speculative_sampling( Whether to use deterministic kernel implementation, default is ``True``. generator: Optional[torch.Generator] A random number generator for the operation. + seed: Optional[int] + seed value to use for the rng during the sampling operation. + offset: Optional[int] + offset value to use for the rng during the sampling operation. 
Returns ------- @@ -1473,5 +1568,7 @@ def chain_speculative_sampling( output_emitted_draft_token_num, deterministic, generator, + seed, + offset, ) return output_token_ids, output_accepted_token_num, output_emitted_draft_token_num diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py index 4ce93914f4..7db5c78b74 100644 --- a/tests/utils/test_sampling.py +++ b/tests/utils/test_sampling.py @@ -726,6 +726,84 @@ def test_check_tensor_param_top_k(batch_size, vocab_size, k): assert samples.shape == normalized_prob.shape +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +def test_sampling_from_probs_seed_offset_reproducibility(batch_size, vocab_size): + """Test that explicit seed/offset produces reproducible results.""" + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + seed, offset = 12345, 0 + + samples1 = flashinfer.sampling.sampling_from_probs( + normalized_prob, seed=seed, offset=offset + ) + samples2 = flashinfer.sampling.sampling_from_probs( + normalized_prob, seed=seed, offset=offset + ) + + assert torch.all(samples1 == samples2), ( + "Same seed/offset should produce identical samples" + ) + + +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +def test_sampling_from_logits_seed_offset_reproducibility(batch_size, vocab_size): + """Test that explicit seed/offset produces reproducible results.""" + torch.manual_seed(42) + logits = torch.randn(batch_size, vocab_size, device="cuda:0") + + seed, offset = 12345, 0 + + samples1 = flashinfer.sampling.sampling_from_logits( + logits, seed=seed, offset=offset + ) + samples2 = flashinfer.sampling.sampling_from_logits( + logits, seed=seed, offset=offset + ) + + assert torch.all(samples1 == samples2), ( + "Same seed/offset should produce identical samples" + ) + + +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +def test_sampling_different_seed_offset_produces_different_results(vocab_size): + """Test that different seed/offset values produce different samples.""" + torch.manual_seed(42) + batch_size = 1000 + pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + samples_seed1 = flashinfer.sampling.sampling_from_probs( + normalized_prob, seed=12345, offset=0 + ) + samples_seed2 = flashinfer.sampling.sampling_from_probs( + normalized_prob, seed=67890, offset=0 + ) + + samples_offset1 = flashinfer.sampling.sampling_from_probs( + normalized_prob, seed=12345, offset=0 + ) + samples_offset2 = flashinfer.sampling.sampling_from_probs( + normalized_prob, seed=12345, offset=1000 + ) + + seed_match_rate = (samples_seed1 == samples_seed2).float().mean().item() + offset_match_rate = (samples_offset1 == samples_offset2).float().mean().item() + + assert seed_match_rate < 1, ( + f"Different seeds should produce mostly different samples, " + f"got {seed_match_rate:.2%} match rate" + ) + assert offset_match_rate < 1, ( + f"Different offsets should produce mostly different samples, " + f"got {offset_match_rate:.2%} match rate" + ) + + if __name__ == "__main__": # test_sampling_freq(128256, gumbel_distribution(0.1), 0.5) test_sampling_from_logits_freq(128256, gumbel_distribution(0.1)) From df5c2e45ca4a873d3f0ced182584acf71826a00b Mon Sep 17 00:00:00 2001 From: kahyun 
<69875166+kahyunnam@users.noreply.github.com>
Date: Thu, 27 Nov 2025 18:14:35 -0800
Subject: [PATCH 093/130] ci: Reduce test time by moving compilation off-line
 (#2089)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## 📌 Description

Download `flashinfer-cubin` and `flashinfer-jit-cache` to avoid compilation. (If a JIT kernel is not in `flashinfer-jit-cache`, it will still be JIT-compiled during test runtime. We could set `export FLASHINFER_DISABLE_JIT=1` to avoid this, but then a lot of tests that use JIT kernels not found in `flashinfer-jit-cache` would be skipped.)

## 🔍 Related Issues

Issue was discussed on Slack: "Ideally, we would move that compilation off-line, which would reduce test time & make kernel hang detection much easier."

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Chores**
  * Improved runtime install flow to detect CUDA, compute an effective JIT architecture mapping, and install matching precompiled kernel artifacts plus local package sources; these steps run only outside dry-run mode and verify installation by showing configuration.
  * Simplified build parallelism calculation to a constant division by 8 (with existing safety guards retained).
* **Bug Fixes**
  * Missing precompiled kernel artifacts now cause an explicit error/abort instead of a warning.

✏️ Tip: You can customize this high-level summary in your review settings.
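
The script change below encodes a small architecture-suffix rule when picking the prebuilt `flashinfer-jit-cache` wheel. Written out as a hypothetical Python helper (names are illustrative; the shell diff below is authoritative):

```python
def effective_jit_arch(jit_arch: str, cuda_stream: str) -> str:
    # Compute capability 12.0 maps to the "12.0a" wheel on CUDA 12.9 (cu129)
    # and to "12.0f" on CUDA 13.x (cu130); other values pass through unchanged.
    if jit_arch == "12.0":
        return "12.0a" if cuda_stream == "cu129" else "12.0f"
    return jit_arch


assert effective_jit_arch("12.0", "cu129") == "12.0a"
assert effective_jit_arch("12.0", "cu130") == "12.0f"
assert effective_jit_arch("10.0a", "cu130") == "10.0a"
```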
--------- Co-authored-by: yzh119 --- scripts/build_flashinfer_jit_cache_whl.sh | 3 +- scripts/task_test_blackwell_kernels.sh | 68 +++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/scripts/build_flashinfer_jit_cache_whl.sh b/scripts/build_flashinfer_jit_cache_whl.sh index 4d00ae67f0..ad35fbf640 100755 --- a/scripts/build_flashinfer_jit_cache_whl.sh +++ b/scripts/build_flashinfer_jit_cache_whl.sh @@ -11,7 +11,8 @@ echo "==========================================" # MAX_JOBS = min(nproc, max(1, MemAvailable_GB/4)) MEM_AVAILABLE_GB=$(free -g | awk '/^Mem:/ {print $7}') NPROC=$(nproc) -MAX_JOBS=$(( MEM_AVAILABLE_GB / $([ "$(uname -m)" = "aarch64" ] && echo 8 || echo 4) )) +# MAX_JOBS=$(( MEM_AVAILABLE_GB / $([ "$(uname -m)" = "aarch64" ] && echo 8 || echo 4) )) +MAX_JOBS=$(( MEM_AVAILABLE_GB / 8 )) if (( MAX_JOBS < 1 )); then MAX_JOBS=1 elif (( NPROC < MAX_JOBS )); then diff --git a/scripts/task_test_blackwell_kernels.sh b/scripts/task_test_blackwell_kernels.sh index 312cf12eb1..0d7b0b1f4a 100644 --- a/scripts/task_test_blackwell_kernels.sh +++ b/scripts/task_test_blackwell_kernels.sh @@ -25,7 +25,75 @@ if [[ "$1" == "--dry-run" ]] || [[ "${DRY_RUN}" == "true" ]]; then fi if [ "$DRY_RUN" != "true" ]; then + echo "Using CUDA version: ${CUDA_VERSION}" + echo "" + + # Install precompiled kernels (require CI build artifacts) + JIT_ARCH_EFFECTIVE="" + # Map CUDA_VERSION to CUDA_STREAM for artifact lookup + if [[ "${CUDA_VERSION}" == cu* ]]; then + CUDA_STREAM="${CUDA_VERSION}" + elif [ "${CUDA_VERSION}" = "12.9.0" ]; then + CUDA_STREAM="cu129" + else + CUDA_STREAM="cu130" + fi + echo "Using CUDA stream: ${CUDA_STREAM}" + echo "" + if [ -n "${JIT_ARCH}" ]; then + # 12.0a for CUDA 12.9.0, 12.0f for CUDA 13.0.0 + if [ "${JIT_ARCH}" = "12.0" ]; then + if [ "${CUDA_STREAM}" = "cu129" ]; then + JIT_ARCH_EFFECTIVE="12.0a" + else + JIT_ARCH_EFFECTIVE="12.0f" + fi + else + JIT_ARCH_EFFECTIVE="${JIT_ARCH}" + fi + + echo "Using JIT_ARCH from environment: ${JIT_ARCH_EFFECTIVE}" + DIST_CUBIN_DIR="../dist/${CUDA_STREAM}/${JIT_ARCH_EFFECTIVE}/cubin" + DIST_JIT_CACHE_DIR="../dist/${CUDA_STREAM}/${JIT_ARCH_EFFECTIVE}/jit-cache" + + echo "==== Debug: listing artifact directories ====" + echo "Tree under ../dist:" + (cd .. && ls -al dist) || true + echo "" + echo "Tree under ../dist/${CUDA_STREAM}:" + (cd .. && ls -al "dist/${CUDA_STREAM}") || true + echo "" + echo "Contents of ${DIST_CUBIN_DIR}:" + ls -al "${DIST_CUBIN_DIR}" || true + echo "" + echo "Contents of ${DIST_JIT_CACHE_DIR}:" + ls -al "${DIST_JIT_CACHE_DIR}" || true + echo "=============================================" + + if [ -d "${DIST_CUBIN_DIR}" ] && ls "${DIST_CUBIN_DIR}"/*.whl >/dev/null 2>&1; then + echo "Installing flashinfer-cubin from ${DIST_CUBIN_DIR} ..." + pip install -q "${DIST_CUBIN_DIR}"/*.whl + else + echo "ERROR: flashinfer-cubin wheel not found in ${DIST_CUBIN_DIR}. Ensure the CI build stage produced the artifact." >&2 + fi + + if [ -d "${DIST_JIT_CACHE_DIR}" ] && ls "${DIST_JIT_CACHE_DIR}"/*.whl >/dev/null 2>&1; then + echo "Installing flashinfer-jit-cache from ${DIST_JIT_CACHE_DIR} ..." + pip install -q "${DIST_JIT_CACHE_DIR}"/*.whl + else + echo "ERROR: flashinfer-jit-cache wheel not found in ${DIST_JIT_CACHE_DIR} for ${CUDA_VERSION}. Ensure the CI build stage produced the artifact." >&2 + fi + echo "" + fi + + # Install local python sources pip install -e . -v --no-deps + echo "" + + # Verify installation + echo "Verifying installation..." 
+ (cd /tmp && python -m flashinfer show-config) + echo "" fi EXIT_CODE=0 From b14408b20a19ef2d17c092c29384c98efc4e8a1f Mon Sep 17 00:00:00 2001 From: Jimmy Zhou <79552142+jimmyzho@users.noreply.github.com> Date: Fri, 28 Nov 2025 02:35:15 -0500 Subject: [PATCH 094/130] feat: TRTLLM FMHAv2 backend for ctx attention (#2142) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## ๐Ÿ“Œ Description Porting over the [trtllm fmhav2 library](https://github.com/NVIDIA/TensorRT-LLM/tree/main/cpp/kernels/fmha_v2) to support prefill cases. ## ๐Ÿ” Related Issues ## ๐Ÿš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### โœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## ๐Ÿงช Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes ## Summary by CodeRabbit * **New Features** * INT8 quantization and FP8 (E4M3/E5M2) conversion utilities, plus broad packed 8/16โ€‘bit output paths. * Hopper GMMA/TMA optimizations and SM90 GMMA/IGMMA helpers for highโ€‘performance kernels. * Extensive FMHA v2 tiling/load/store primitives (Q/K/V/O), TMA descriptor management, and paged KV cache. * **Enhanced Support** * Alibi positional-bias params, BF16/mixed-precision conversions, causal/sliding-window masks and multiโ€‘token prediction. โœ๏ธ Tip: You can customize this high-level summary in your review settings. 
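
Among other things, the patch adds scaled dtype-conversion helpers in `csrc/fmha_v2/convert.cu` (shown further below). A rough PyTorch reference for sanity-checking such conversions could look like the following; this is an assumption-laden sketch, and the rounding/saturation behaviour of PyTorch's float8 cast is not guaranteed to match the PTX `cvt.rni.sat` path bit-for-bit:

```python
import torch


def fp32_to_e4m3_ref(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Rough analogue of run_conversion_fp32_to_e4m3: scale in fp32, then narrow.
    return (x.float() * scale).to(torch.float8_e4m3fn)


def e4m3_to_fp32_ref(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Rough analogue of run_conversion_e4m3_to_fp32: widen to fp32, then scale.
    return x.float() * scale


def int32_to_int8_ref(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Rough analogue of convert_int32_to_int8_kernel: scale, round to nearest
    # (ties to even), saturate to [-128, 127].
    return (x.float() * scale).round().clamp(-128, 127).to(torch.int8)
```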
--- csrc/fmha_v2/convert.cu | 196 + csrc/fmha_v2/fmha/alibi_params.h | 50 + csrc/fmha_v2/fmha/fragment.h | 2311 ++++++ csrc/fmha_v2/fmha/gemm.h | 35 + csrc/fmha_v2/fmha/gmem_tile_o.h | 465 ++ csrc/fmha_v2/fmha/gmem_tile_o_packed.h | 1349 ++++ csrc/fmha_v2/fmha/gmem_tile_ps.h | 837 ++ csrc/fmha_v2/fmha/gmem_tile_qkv.h | 167 + csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h | 1307 ++++ csrc/fmha_v2/fmha/hopper/arrive_wait.h | 396 + csrc/fmha_v2/fmha/hopper/compute_tile.h | 503 ++ csrc/fmha_v2/fmha/hopper/fragment.h | 491 ++ csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h | 1138 +++ .../fmha/hopper/gmem_tile_qkv_packed.h | 146 + csrc/fmha_v2/fmha/hopper/gmma_descriptor.h | 547 ++ csrc/fmha_v2/fmha/hopper/kernel_traits.h | 365 + csrc/fmha_v2/fmha/hopper/smem_tile.h | 2423 ++++++ csrc/fmha_v2/fmha/hopper/smem_tile_o.h | 325 + csrc/fmha_v2/fmha/hopper/tma_descriptor.h | 348 + csrc/fmha_v2/fmha/hopper/tma_types.h | 123 + csrc/fmha_v2/fmha/hopper/utils_gmma.h | 18 + csrc/fmha_v2/fmha/hopper/utils_hgmma.h | 874 +++ csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h | 475 ++ csrc/fmha_v2/fmha/hopper/utils_igmma.h | 396 + csrc/fmha_v2/fmha/hopper/utils_qgmma.h | 2089 +++++ csrc/fmha_v2/fmha/hopper/utils_tma.h | 155 + csrc/fmha_v2/fmha/hopper/utils_warpgroup.h | 44 + csrc/fmha_v2/fmha/kernel_traits.h | 879 +++ csrc/fmha_v2/fmha/mask.h | 785 ++ csrc/fmha_v2/fmha/numeric_types.h | 57 + csrc/fmha_v2/fmha/paged_kv_cache.h | 63 + csrc/fmha_v2/fmha/smem_tile.h | 2071 +++++ csrc/fmha_v2/fmha/smem_tile_o.h | 1646 ++++ csrc/fmha_v2/fmha/smem_tile_qkv.h | 592 ++ csrc/fmha_v2/fmha/smem_tile_v.h | 1008 +++ csrc/fmha_v2/fmha/softmax.h | 3964 ++++++++++ csrc/fmha_v2/fmha/traits.h | 942 +++ csrc/fmha_v2/fmha/utils.h | 2355 ++++++ csrc/fmha_v2/fmha/warpspec/circular_buffer.h | 399 + csrc/fmha_v2/fmha/warpspec/compute.h | 606 ++ csrc/fmha_v2/fmha/warpspec/dma.h | 874 +++ csrc/fmha_v2/fmha/warpspec/epilogue.h | 1091 +++ csrc/fmha_v2/fmha/warpspec/kernel_traits.h | 574 ++ csrc/fmha_v2/fused_multihead_attention.cpp | 1982 +++++ csrc/fmha_v2/fused_multihead_attention.h | 326 + ...sed_multihead_attention_demo_bert_params.h | 171 + .../fused_multihead_attention_kernel.h | 237 + .../fused_multihead_attention_kernel_1xN.h | 360 + ...multihead_attention_kernel_1xN_multi_cta.h | 465 ++ ...ed_multihead_attention_kernel_1xN_noloop.h | 316 + .../fused_multihead_attention_kernel_2x2.h | 286 + ...ed_multihead_attention_kernel_4x1_hopper.h | 742 ++ ...ihead_attention_kernel_4x1_hopper_noloop.h | 330 + ...ed_multihead_attention_kernel_4xN_hopper.h | 371 + ...ihead_attention_kernel_4xN_hopper_noloop.h | 382 + .../fmha_v2/fused_multihead_attention_utils.h | 1472 ++++ .../fused_multihead_cross_attention.cpp | 939 +++ .../fmha_v2/fused_multihead_cross_attention.h | 67 + ...sed_multihead_cross_attention_kernel_1xN.h | 361 + ...tihead_cross_attention_kernel_1xN_noloop.h | 319 + .../fused_multihead_flash_attention_kernel.h | 568 ++ ..._multihead_flash_attention_kernel_noloop.h | 645 ++ ...head_flash_attention_kernel_noloop_tiled.h | 577 ++ csrc/fmha_v2/softmax_bf16.cu | 21 + csrc/fmha_v2/softmax_fp16.cu | 21 + csrc/fmha_v2/softmax_fp32.cu | 21 + csrc/fmha_v2/softmax_fp8.cu | 22 + csrc/fmha_v2/softmax_impl.h | 1004 +++ csrc/fmha_v2/softmax_int8.cu | 22 + csrc/trtllm_fmha_v2_binding.cu | 430 + flashinfer/jit/__init__.py | 1 + flashinfer/jit/attention/__init__.py | 2 + .../jit/attention/fmha_v2/generate_kernels.py | 182 + .../jit/attention/fmha_v2/generator_utils.py | 6927 +++++++++++++++++ flashinfer/jit/attention/modules.py | 41 + flashinfer/prefill.py | 81 
+ .../test_fmha_v2_prefill_deepseek.py | 167 + 77 files changed, 55337 insertions(+) create mode 100644 csrc/fmha_v2/convert.cu create mode 100644 csrc/fmha_v2/fmha/alibi_params.h create mode 100644 csrc/fmha_v2/fmha/fragment.h create mode 100644 csrc/fmha_v2/fmha/gemm.h create mode 100644 csrc/fmha_v2/fmha/gmem_tile_o.h create mode 100644 csrc/fmha_v2/fmha/gmem_tile_o_packed.h create mode 100644 csrc/fmha_v2/fmha/gmem_tile_ps.h create mode 100644 csrc/fmha_v2/fmha/gmem_tile_qkv.h create mode 100644 csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h create mode 100644 csrc/fmha_v2/fmha/hopper/arrive_wait.h create mode 100644 csrc/fmha_v2/fmha/hopper/compute_tile.h create mode 100644 csrc/fmha_v2/fmha/hopper/fragment.h create mode 100644 csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h create mode 100644 csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h create mode 100644 csrc/fmha_v2/fmha/hopper/gmma_descriptor.h create mode 100644 csrc/fmha_v2/fmha/hopper/kernel_traits.h create mode 100644 csrc/fmha_v2/fmha/hopper/smem_tile.h create mode 100644 csrc/fmha_v2/fmha/hopper/smem_tile_o.h create mode 100644 csrc/fmha_v2/fmha/hopper/tma_descriptor.h create mode 100644 csrc/fmha_v2/fmha/hopper/tma_types.h create mode 100644 csrc/fmha_v2/fmha/hopper/utils_gmma.h create mode 100644 csrc/fmha_v2/fmha/hopper/utils_hgmma.h create mode 100644 csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h create mode 100644 csrc/fmha_v2/fmha/hopper/utils_igmma.h create mode 100644 csrc/fmha_v2/fmha/hopper/utils_qgmma.h create mode 100644 csrc/fmha_v2/fmha/hopper/utils_tma.h create mode 100644 csrc/fmha_v2/fmha/hopper/utils_warpgroup.h create mode 100644 csrc/fmha_v2/fmha/kernel_traits.h create mode 100644 csrc/fmha_v2/fmha/mask.h create mode 100644 csrc/fmha_v2/fmha/numeric_types.h create mode 100644 csrc/fmha_v2/fmha/paged_kv_cache.h create mode 100644 csrc/fmha_v2/fmha/smem_tile.h create mode 100644 csrc/fmha_v2/fmha/smem_tile_o.h create mode 100644 csrc/fmha_v2/fmha/smem_tile_qkv.h create mode 100644 csrc/fmha_v2/fmha/smem_tile_v.h create mode 100644 csrc/fmha_v2/fmha/softmax.h create mode 100644 csrc/fmha_v2/fmha/traits.h create mode 100644 csrc/fmha_v2/fmha/utils.h create mode 100644 csrc/fmha_v2/fmha/warpspec/circular_buffer.h create mode 100644 csrc/fmha_v2/fmha/warpspec/compute.h create mode 100644 csrc/fmha_v2/fmha/warpspec/dma.h create mode 100644 csrc/fmha_v2/fmha/warpspec/epilogue.h create mode 100644 csrc/fmha_v2/fmha/warpspec/kernel_traits.h create mode 100644 csrc/fmha_v2/fused_multihead_attention.cpp create mode 100644 csrc/fmha_v2/fused_multihead_attention.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel_1xN.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel_1xN_multi_cta.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel_1xN_noloop.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel_2x2.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel_4x1_hopper.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel_4x1_hopper_noloop.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel_4xN_hopper.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_kernel_4xN_hopper_noloop.h create mode 100644 csrc/fmha_v2/fused_multihead_attention_utils.h create mode 100644 csrc/fmha_v2/fused_multihead_cross_attention.cpp create mode 100644 csrc/fmha_v2/fused_multihead_cross_attention.h create 
mode 100644 csrc/fmha_v2/fused_multihead_cross_attention_kernel_1xN.h create mode 100644 csrc/fmha_v2/fused_multihead_cross_attention_kernel_1xN_noloop.h create mode 100644 csrc/fmha_v2/fused_multihead_flash_attention_kernel.h create mode 100644 csrc/fmha_v2/fused_multihead_flash_attention_kernel_noloop.h create mode 100644 csrc/fmha_v2/fused_multihead_flash_attention_kernel_noloop_tiled.h create mode 100644 csrc/fmha_v2/softmax_bf16.cu create mode 100644 csrc/fmha_v2/softmax_fp16.cu create mode 100644 csrc/fmha_v2/softmax_fp32.cu create mode 100644 csrc/fmha_v2/softmax_fp8.cu create mode 100644 csrc/fmha_v2/softmax_impl.h create mode 100644 csrc/fmha_v2/softmax_int8.cu create mode 100644 csrc/trtllm_fmha_v2_binding.cu create mode 100644 flashinfer/jit/attention/fmha_v2/generate_kernels.py create mode 100755 flashinfer/jit/attention/fmha_v2/generator_utils.py create mode 100644 tests/attention/test_fmha_v2_prefill_deepseek.py diff --git a/csrc/fmha_v2/convert.cu b/csrc/fmha_v2/convert.cu new file mode 100644 index 0000000000..345bd008f9 --- /dev/null +++ b/csrc/fmha_v2/convert.cu @@ -0,0 +1,196 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) { + // The step. + size_t step = (size_t)gridDim.x * blockDim.x; + + // Iterate over the elements. + for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { + // Load 4 integers. + int4 tmp = reinterpret_cast(src)[ii]; + + // Convert to float and scale. + float x = static_cast(tmp.x) * scale; + float y = static_cast(tmp.y) * scale; + float z = static_cast(tmp.z) * scale; + float w = static_cast(tmp.w) * scale; + + // Convert to int8. + uint32_t a; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(a) : "f"(x)); + uint32_t b; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(b) : "f"(y)); + uint32_t c; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(c) : "f"(z)); + uint32_t d; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + + // Compact. + char4 out; + out.x = reinterpret_cast(a); + out.y = reinterpret_cast(b); + out.z = reinterpret_cast(c); + out.w = reinterpret_cast(d); + + // Store. 
+ reinterpret_cast(dst)[ii] = reinterpret_cast(out); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, + float scale) { + size_t n = (size_t)s * b * h * d; + convert_int32_to_int8_kernel<<<512, 256>>>(dst, src, n, scale); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline typename fmha::Uint_from_size_in_bytes::Type pack_float4( + float4 const& f); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +__device__ inline uint2 pack_float4(float4 const& f) { + return fmha::float4_to_half4(f.x, f.y, f.z, f.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +__device__ inline uint2 pack_float4(float4 const& f) { + return fmha::float4_to_16bit_x4(f.x, f.y, f.z, f.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +__device__ inline uint32_t pack_float4(float4 const& f) { + return fmha::float4_to_e4m3x4(f.x, f.y, f.z, f.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template <> +__device__ inline uint32_t pack_float4(float4 const& f) { + return fmha::float4_to_e5m2x4(f.x, f.y, f.z, f.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { + using Dst = typename fmha::Uint_from_size_in_bytes::Type; + + // The step. + size_t step = (size_t)gridDim.x * blockDim.x; + + // Iterate over the elements. + for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { + // Load 4 floats. + float4 tmp = reinterpret_cast(src)[ii]; + // Scale. + tmp.x *= scale; + tmp.y *= scale; + tmp.z *= scale; + tmp.w *= scale; + // Convert to 4 Ts. + auto out = pack_float4(tmp); + + // Store. + reinterpret_cast(dst)[ii] = reinterpret_cast(out); + } +} + +template +__global__ void convert_T_to_fp32_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { + using Src = typename fmha::Uint_from_size_in_bytes::Type; + + union { + Src raw; + T elt[4]; + } data; + + // The step. + size_t step = (size_t)gridDim.x * blockDim.x; + + // Iterate over the elements. + for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { + // Load 4 floats. + data.raw = reinterpret_cast(src)[ii]; + float4 out; + // Scale. + out.x = float(data.elt[0]) * scale; + out.y = float(data.elt[1]) * scale; + out.z = float(data.elt[2]) * scale; + out.w = float(data.elt[3]) * scale; + + // Store. + reinterpret_cast(dst)[ii] = reinterpret_cast(out); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_fp16(void* dst, void const* src, int s, int b, int h, int d) { + // No need to expose the scale factor for FP16/FP32. 
+ size_t n = (size_t)s * b * h * d; + convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, 1.f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_bf16(void* dst, void const* src, int s, int b, int h, int d) { + // No need to expose the scale factor for FP16/FP32. + size_t n = (size_t)s * b * h * d; + convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, 1.f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_e4m3(void* dst, void const* src, size_t n, float scale_o) { + convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, scale_o); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_e4m3_to_fp32(void* dst, void const* src, size_t n, float scale_o) { + convert_T_to_fp32_kernel<<<512, 256>>>(dst, src, n, scale_o); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d, + float scale_o) { + run_conversion_fp32_to_e4m3(dst, src, s * b * h * d, scale_o); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_fp32_to_e5m2(void* dst, void const* src, size_t n, float scale_o) { + convert_fp32_to_T_kernel<<<512, 256>>>(dst, src, n, scale_o); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_conversion_e5m2_to_fp32(void* dst, void const* src, size_t n, float scale_o) { + convert_T_to_fp32_kernel<<<512, 256>>>(dst, src, n, scale_o); +} diff --git a/csrc/fmha_v2/fmha/alibi_params.h b/csrc/fmha_v2/fmha/alibi_params.h new file mode 100644 index 0000000000..bee7ea1be9 --- /dev/null +++ b/csrc/fmha_v2/fmha/alibi_params.h @@ -0,0 +1,50 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +struct AlibiParams { + constexpr static int round_down_to_power_two(int x) { + x = x | (x >> 1); + x = x | (x >> 2); + x = x | (x >> 4); + x = x | (x >> 8); + x = x | (x >> 16); + return x - (x >> 1); + } + + AlibiParams() = default; + + AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) { + h_pow_2 = round_down_to_power_two(h); + alibi_neg4_div_h = -4.0f / h_pow_2; + } + + AlibiParams(int h, int s, int tp_size, int rank, float scale_after_alibi = 1.f) + : AlibiParams(h * tp_size, scale_after_alibi) { + head_idx_offset = h * rank; + sequence_pos_offset = s * rank; + } + + int h_pow_2{}; + float alibi_neg4_div_h{}; + float scale_after_alibi{}; + // Could be simplified to `int rank` derive the others as `num_heads * rank, s * rank` at + // runtime, but this makes assumptions about the layout downstream + // (e.g. 
downstream may only split across the head dimension, so s would be the full sequence) + int head_idx_offset = 0; + int sequence_pos_offset = 0; +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/fragment.h b/csrc/fmha_v2/fmha/fragment.h new file mode 100644 index 0000000000..01bdc0fdac --- /dev/null +++ b/csrc/fmha_v2/fmha/fragment.h @@ -0,0 +1,2311 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include + +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_ldg {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<1> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint8_t tmp; + fmha::ldg(tmp, ptr); + f.u8(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<2> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint16_t tmp; + fmha::ldg(tmp, ptr); + f.u16(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<4> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint32_t tmp; + fmha::ldg(tmp, ptr); + f.reg(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<8> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint2 tmp; + fmha::ldg(tmp, ptr); + f.reg(2 * ii + 0) = tmp.x; + f.reg(2 * ii + 1) = tmp.y; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_ldg<16> { + template + static inline __device__ void ldg(Fragment& f, int ii, void const* ptr) { + uint4 tmp; + fmha::ldg(tmp, ptr); + f.reg(4 * ii + 0) = tmp.x; + f.reg(4 * ii + 1) = tmp.y; + f.reg(4 * ii + 2) = tmp.z; + f.reg(4 * ii + 3) = tmp.w; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_lds {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_lds<2> { + template + static inline __device__ void lds(Fragment& f, int ii, uint32_t ptr) { + uint16_t tmp; + fmha::lds(tmp, ptr); + f.u16(ii) = tmp; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_lds<4> { + template + static inline __device__ void lds(Fragment& f, int ii, uint32_t ptr) { + uint32_t tmp; + fmha::lds(tmp, ptr); + f.reg(ii) = tmp; + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_lds<8> { + template + static inline __device__ void lds(Fragment& f, int ii, uint32_t ptr) { + uint2 tmp; + fmha::lds(tmp, ptr); + f.reg(2 * ii + 0) = tmp.x; + f.reg(2 * ii + 1) = tmp.y; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_lds<16> { + template + static inline __device__ void lds(Fragment& f, int ii, uint32_t ptr) { + uint4 tmp; + fmha::lds(tmp, ptr); + f.reg(4 * ii + 0) = tmp.x; + f.reg(4 * ii + 1) = tmp.y; + f.reg(4 * ii + 2) = tmp.z; + f.reg(4 * ii + 3) = tmp.w; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// template<> +// struct Fragment_lds<32> { +// template< typename Fragment > +// static inline __device__ void lds(Fragment &f, int ii, uint32_t ptr) { +// uint4 tmp; +// fmha::lds(tmp, ptr); +// f.reg(8*ii+0) = tmp.x; +// f.reg(8*ii+1) = tmp.y; +// f.reg(8*ii+2) = tmp.z; +// f.reg(8*ii+3) = tmp.w; +// +// fmha::lds(tmp, static_cast(ptr)+sizeof(uint4)); +// f.reg(8*ii+4) = tmp.x; +// f.reg(8*ii+5) = tmp.y; +// f.reg(8*ii+6) = tmp.z; +// f.reg(8*ii+7) = tmp.w; +// } +// }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_stg {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<1> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + fmha::stg(ptr, f.u8(ii)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<2> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + fmha::stg(ptr, f.u16(ii)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<4> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + fmha::stg(ptr, f.reg(ii)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<8> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + uint2 tmp; + tmp.x = f.reg(2 * ii + 0); + tmp.y = f.reg(2 * ii + 1); + fmha::stg(ptr, tmp); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_stg<16> { + template + static inline __device__ void stg(void* ptr, Fragment const& f, int ii = 0) { + uint4 tmp; + tmp.x = f.reg(4 * ii + 0); + tmp.y = f.reg(4 * ii + 1); + tmp.z = f.reg(4 * ii + 2); + tmp.w = f.reg(4 * ii + 3); + fmha::stg(ptr, tmp); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_base_ { + // The data type. + using Data_type = Data_type_; + // default input type + using Input_type_ = Data_type_; + + // Does it store the array of elements. + enum { HAS_ELTS = BITS_PER_ELT_ >= 8 }; + + // The number of elements. + enum { NUM_ELTS = NUM_ELTS_ }; + + // The size of element in bits. + enum { BITS_PER_ELT = BITS_PER_ELT_ }; + + // The size of byte of a single register. 
+ enum { BYTES_PER_REG = 4 }; + + // The size in bits. + enum { BITS_PER_REG = BYTES_PER_REG * 8 }; + + // The number of registers needed to store the fragment. + enum { NUM_REGS = Div_up::VALUE }; + + // The size in bytes (as returned by sizeof(Fragment_base<>). + enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG }; + + // The alignment. + enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The type of the elements. + typename Data_type_, + // The number of elements. + int NUM_ELTS_, + // The size of each element in bits. + int BITS_PER_ELT_, + // The alignment if you want to force a value -- use 0 otherwise. + int ALIGNMENT_, + // The base class. + typename Base_ = Fragment_base_> +struct alignas(static_cast(Base_::ALIGNMENT)) Fragment_base : public Base_ { + // The size of a load/store. + enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) }; + + // Clear the fragment. Using PTX in that code seems to produce better SASS... + inline __device__ void clear() { +#pragma unroll + for (int ii = 0; ii < Base_::NUM_REGS; ++ii) { + asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) :); + } + } + + // Load from global memory. + inline __device__ void ldg(void const* ptr) { + Fragment_ldg::ldg(*this, 0, ptr); + } + + // Load from shared memory. + inline __device__ void lds(uint32_t ptr) { + Fragment_lds::lds(*this, 0, ptr); + } + + // Immutable access to a register. + inline __device__ uint32_t const& reg(int ii) const { return this->regs_[ii]; } + + // Mutable access to a register. + inline __device__ uint32_t& reg(int ii) { return this->regs_[ii]; } + + // Set the fragment with a scalar + inline __device__ void set(uint32_t value) { +#pragma unroll + for (int ii = 0; ii < Base_::NUM_REGS; ++ii) { + this->reg(ii) = value; + } + } + + // Store to global memory. + inline __device__ void stg(void* ptr) const { + Fragment_stg::stg(ptr, *this, 0); + } + + // Immutable access to a byte. + inline __device__ uint8_t u8(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to a u8. + inline __device__ uint8_t& u8(int ii) { return reinterpret_cast(&this->regs_[0])[ii]; } + + // Immutable access to a half-word.. + inline __device__ uint16_t u16(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to a half-word. + inline __device__ uint16_t& u16(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Immutable access to a word. + inline __device__ uint32_t u32(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to a word. + inline __device__ uint32_t& u32(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Immutable access to a word. + inline __device__ uint2 u64(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to a word. + inline __device__ uint2& u64(int ii) { return reinterpret_cast(&this->regs_[0])[ii]; } + + // The storage in registers. + // + // NOTE: Instead of using only an array of uint32_t, we could use a union so we could either + // access the registers or the elements. We found that for: + // + // union { + // uint16_t elts_[4]; uint32_t regs_[2]; + // }; + // + // The compiler does not always produce a final structure of 8B. So, for the moment we are + // going to go only with the regs_ array and use reinterpret_cast<> to access elements (see + // below). 
It may be worth revisiting that when time permits. + uint32_t regs_[Base_::NUM_REGS]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment : public Fragment_base { + // Immutable access to the elements. + inline __device__ Data_type_ const& elt(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to the elements. + inline __device__ Data_type_& elt(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Immutable access to the elements with a cast. + template + inline __device__ Cast_type const& elt_as(int ii) const { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Mutable access to the elements. + template + inline __device__ Cast_type& elt_as(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + + // Add another fragment. + inline __device__ void add(Fragment const& other) { +#pragma unroll + for (int ii = 0; ii < NUM_ELTS_; ++ii) { + this->elt(ii) += other.elt(ii); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + 
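// Note: the Fragment_a / Fragment_b specializations are thin per-architecture operand
// fragments; each one only fixes the element type and register count that the matching
// mma.sync instruction in the Fragment_accumulator specializations further below expects.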
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_b : public Fragment {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_accumulator {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The traits. + using Traits = Volta_hmma_fp16_traits; + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // HMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(0)), "+r"(this->reg(1)), "+r"(this->reg(2)), "+r"(this->reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(0)), "+r"(this->reg(1)), "+r"(this->reg(2)), "+r"(this->reg(3)) + : "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // HMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { + // K = 0..3 for threads 0..7 and 16..23 and K = 4..7 for 8..15 and 24..31. 
+ asm volatile( + "mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(0)), "+r"(this->reg(1)), "+r"(this->reg(2)), "+r"(this->reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(4)), "+r"(this->reg(5)), "+r"(this->reg(6)), "+r"(this->reg(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(2)), "r"(b.reg(3))); + + // K = 8..11 for threads 0..7 and 16..23 and K = 12..15 for 8..15 and 24..31. + asm volatile( + "mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(0)), "+r"(this->reg(1)), "+r"(this->reg(2)), "+r"(this->reg(3)) + : "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(4)), "r"(b.reg(5))); + asm volatile( + "mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(this->reg(4)), "+r"(this->reg(5)), "+r"(this->reg(6)), "+r"(this->reg(7)) + : "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(6)), "r"(b.reg(7))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 \n" + " {%0, %1}, \n" + " {%2}, \n" + " {%3}, \n" + " {%0, %1}; \n" + : "+r"(this->reg(2 * i + 0)), "+r"(this->reg(2 * i + 1)) + : "r"(a.reg(i / 2)), "r"(b.reg(i % 2))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 \n" + " {%0, %1}, \n" + " {%2, %3}, \n" + " {%4}, \n" + " {%0, %1}; \n" + : "+r"(reg(0)), "+r"(reg(1)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(0))); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 \n" + " {%0, %1}, \n" + " {%2, %3}, \n" + " {%4}, \n" + " {%0, %1}; \n" + : "+r"(reg(2)), "+r"(reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(1))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + inline __device__ void mul(float const other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) *= other; + } + } + + // Do the HMMA. 
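  // (two m16n8k8 HMMAs are issued: the shared A fragment is combined with b.reg(0) to update
  //  elt(0..3) and with b.reg(1) to update elt(4..7), accumulating in FP32)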
+ template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(0))); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(b.reg(1))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 \n" + " {%0, %1}, \n" + " {%2}, \n" + " {%3}, \n" + " {%0, %1}; \n" + : "+r"(this->reg(2 * i + 0)), "+r"(this->reg(2 * i + 1)) + : "r"(a.reg(i / 2)), "r"(b.reg(i % 2))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(0)), "+r"(reg(1)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(2)), "+r"(reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// BF16 MMA must accumulate with at least FP32 +template <> +struct Fragment_accumulator : public Fragment { + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(0)), "+r"(reg(1)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(2)), "+r"(reg(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. 
+ template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + inline __device__ void mul(float const other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) *= other; + } + } + + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// BF16 MMA must accumulate with at least FP32 +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + inline __device__ void mul(float const other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) *= other; + } + } + + // Do the HMMA. + template + inline __device__ void mma(Fragment_a const& a, + Fragment_b const& b) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(0)), "r"(b.reg(1))); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(2)), "r"(b.reg(3))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 2; ++i) { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(reg(i * 4 + 0)), "+r"(reg(i * 4 + 1)), "+r"(reg(i * 4 + 2)), "+r"(reg(i * 4 + 3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(i * 2)), + "r"(b.reg(i * 2 + 1))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. 
+ using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 2; ++i) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+r"(reg(i * 4 + 0)), "+r"(reg(i * 4 + 1)), "+r"(reg(i * 4 + 2)), "+r"(reg(i * 4 + 3)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(i * 2)), + "r"(b.reg(i * 2 + 1))); +#else + asm volatile("trap;\n"); +#endif + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_accumulator : public Fragment { + // The base class. + using Base = Fragment; + + // The fragments. + using Fragment_a = Fragment_a; + using Fragment_b = Fragment_b; + + // IMMA. + inline __device__ void mma(Fragment_a const& a, Fragment_b const& b) { +#pragma unroll + for (int i = 0; i < 2; ++i) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 \n" + " {%0, %1}, \n" + " {%2, %3, %4, %5}, \n" + " {%6, %7}, \n" + " {%0, %1}; \n" + : "+r"(reg(i * 2 + 0)), "+r"(reg(i * 2 + 1)) + : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)), "r"(b.reg(i * 2)), + "r"(b.reg(i * 2 + 1))); +#else + asm volatile("trap;\n"); +#endif + } + } +}; + +template +struct Tile_o_normalizer { + // The fragment accumulator. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 4 }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Initialize the attention sinks. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : attention_sink_value_(params.attention_sinks != nullptr ? params.attention_sinks[binfo.bidh] + : -FLT_MAX) {} + + // Update the sum when attention sinks are used. + inline __device__ void update_sum(float const (&max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#pragma unroll + for (int i = 0; i < ROWS_PER_THREAD; ++i) { + sum[i] += expf(attention_sink_value_ - max[i]); + } + } + + // Update o. + inline __device__ void update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&curr_max)[ROWS_PER_THREAD], + float const (&prev_max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + uint32_t alpha[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + curr_max[jj] = fmax(prev_max[jj], curr_max[jj]); + float a = expf(prev_max[jj] - curr_max[jj]); + sum[jj] *= a; + // Convert back to FP16x2. 
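      // (this is the running-softmax correction: a = exp(prev_max - new_max) rescales the
      //  partial output accumulated under the old row maximum; it is replicated into both
      //  FP16 lanes so a single hmul2 below rescales a packed pair of accumulator values)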
+ alpha[ii] = fmha::float2_to_half2(a, a); + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators in FP16x2. + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(alpha[ii & 1], acc_o_pair); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + curr_max[jj] = fmax(prev_max[jj], curr_max[jj]); + alpha[ii] = expf(prev_max[jj] - curr_max[jj]); + sum[jj] *= alpha[ii]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The registers. + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Do the math in Fp32. + acc_o_pair.x = alpha[ii & 1] * acc_o_pair.x; + acc_o_pair.y = alpha[ii & 1] * acc_o_pair.y; + + // Convert back to Fp16x2. + acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Update o. + inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&sum)[ROWS_PER_THREAD]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + uint32_t beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + float b = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj]; + // Convert back to FP16x2. + beta[ii] = fmha::float2_to_half2(b, b); + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators in FP16x2. + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(acc_o_pair, beta[ii & 1]); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The diviser. + beta[ii] = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The registers. + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Do the math in Fp32. + acc_o_pair.x = acc_o_pair.x * beta[ii & 1]; + acc_o_pair.y = acc_o_pair.y * beta[ii & 1]; + + // Convert back to Fp16x2. + acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Attention sink value. + float attention_sink_value_; +}; + +template +struct Tile_o_normalizer_fp32 { + // The fragment accumulator. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + // The number of MMAs in the N dimension. + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. 
+ enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread. + enum { REGS_PER_THREAD = 8 }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Initialize the attention sinks. + template + inline __device__ Tile_o_normalizer_fp32(Params const& params, Block_info const& binfo) + : attention_sink_value_(params.attention_sinks != nullptr ? params.attention_sinks[binfo.bidh] + : -FLT_MAX) {} + + // Update the sum when attention sinks are used. + inline __device__ void update_sum(float const (&max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#pragma unroll + for (int i = 0; i < ROWS_PER_THREAD; ++i) { + sum[i] += expf(attention_sink_value_ - max[i]); + } + } + + // Update o. + inline __device__ void update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&curr_max)[ROWS_PER_THREAD], + float const (&prev_max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + curr_max[jj] = fmax(prev_max[jj], curr_max[jj]); + alpha[ii] = expf(prev_max[jj] - curr_max[jj]); + sum[jj] *= alpha[ii]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The register for O. + float acc_o_f = acc_o[mi][ni].elt(ii); + // Compute the next accumulator. + acc_o_f = alpha[(ii & 2) / 2] * acc_o_f; + // Update the accumulator. + acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + // Update o after P * V + inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&sum)[ROWS_PER_THREAD]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + + // The diviser. + beta[ii] = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The register for O. + float acc_o_f = acc_o[mi][ni].elt(ii); + // Compute the next accumulator. + acc_o_f = acc_o_f * beta[(ii & 2) / 2]; + // Update the accumulator. + acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + // Attention sink value. + float attention_sink_value_; +}; + +template +struct Tile_o_normalizer + : public Tile_o_normalizer_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Tile_o_normalizer_fp32; + + // The ctor. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} +}; + +template +struct Tile_o_normalizer + : public Tile_o_normalizer_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Tile_o_normalizer_fp32; + + // The ctor. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} +}; + +// The attention sinks are not enabled for Volta. +template +struct Tile_o_normalizer { + // The traits. + using Traits = Volta_hmma_fp16_16x16x16_traits; + + // The fragments. 
+ using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 8 }; + + // Update o. + inline __device__ void update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float (&curr_max)[ROWS_PER_THREAD], + float const (&prev_max)[ROWS_PER_THREAD], + float (&sum)[ROWS_PER_THREAD]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors. + uint32_t alpha; + // Update the curr_max. + curr_max[mi] = fmax(prev_max[mi], curr_max[mi]); + // The multiplier. + float a = expf(prev_max[mi] - curr_max[mi]); + // The accumulated sum. + sum[mi] *= a; + // Convert back to FP16. + alpha = fmha::float2_to_half2(a, a); + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators packed in FP16x2. + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(acc_o_pair, alpha); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Update the curr_max. + curr_max[mi] = fmax(prev_max[mi], curr_max[mi]); + // The multiplier. + float alpha = expf(prev_max[mi] - curr_max[mi]); + // The accumulated sum. + sum[mi] *= alpha; + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators. Convert from FP16x2 to FP32x2. + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Apply the scaling. + acc_o_pair.x = alpha * acc_o_pair.x; + acc_o_pair.y = alpha * acc_o_pair.y; + + // Update the register after converting back to FP16x2. + acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Update o. + inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + float const (&sum)[ROWS_PER_THREAD]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors. + uint32_t beta; + // The divisor. + float b = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + // Convert back to FP16. + beta = fmha::float2_to_half2(b, b); + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators packed in FP16x2. + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(acc_o_pair, beta); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The divisor. + float beta = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The registers. + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Do the math in Fp32. 
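          // (unpack the FP16x2 pair to float, scale both lanes by beta = 1/sum -- forced to
          //  1.0 when the row sum is zero or NaN -- and repack to FP16x2 below)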
+ acc_o_pair.x = acc_o_pair.x * beta; + acc_o_pair.y = acc_o_pair.y * beta; + + // Convert back to Fp16x2. + acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } +}; + +template +struct Tile_o_normalizer + : public Tile_o_normalizer_fp32 { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Tile_o_normalizer_fp32; + + // The ctor. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Update the sum. + inline __device__ void update_sum(float const (&max)[Base::ROWS_PER_THREAD], + float (&sum)[Base::ROWS_PER_THREAD]) { +// Take the log2f(Traits::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to +// sum. +#pragma unroll + for (int i = 0; i < Base::ROWS_PER_THREAD; ++i) { + sum[i] += expf(this->attention_sink_value_ - max[i]) * Traits::SOFTMAX_FP_QUANT_SCALE; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Tile_o_normalizer + : public Tile_o_normalizer_fp32 { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Tile_o_normalizer_fp32; + + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + // The number of MMAs in the N dimension. + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of registers per thread. + enum { REGS_PER_THREAD = 8 }; + + // The ctor. + template + inline __device__ Tile_o_normalizer(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + inline __device__ void merge(Fragment_accu (&acc_dst)[MMAS_M][MMAS_N], + Fragment_accu (&acc_src)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + acc_dst[mi][ni].elt(ii) += acc_src[mi][ni].elt(ii); + } + } + } + } + + template + inline __device__ void move_to_first_block(Params const& params, int bidb, int bidh) { + int scale_iter = bidb * params.h * params.sage.v.max_nblock + bidh * params.sage.v.max_nblock; + + params_scale_v_iter = reinterpret_cast(params.sage.v.scales + scale_iter); + params_scale_v_ = __ldg(params_scale_v_iter); + } + + inline __device__ void move_to_next_block() { + params_scale_v_iter += 1; + params_scale_v_ = __ldg(params_scale_v_iter); + } + + inline __device__ void apply_scale(Fragment_accu (&acc_o)[MMAS_M][MMAS_N]) { + float const scale = reinterpret_cast(params_scale_v_); + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + float acc_o_f = acc_o[mi][ni].elt(ii); + acc_o_f = scale * acc_o_f; + acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + float const* params_scale_v_iter; + float params_scale_v_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_saver { + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. 
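  // (with the m16n8 MMA layouts each thread holds two accumulator rows per MMA in M,
  //  8 rows apart, so it owns 2 * MMAS_M rows of softmax statistics)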
+ enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 4 }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Ctor. + template + inline __device__ Softmax_saver(Params const& params, Block_info const& binfo) + : actual_q_len_(binfo.actual_q_seqlen), + softmax_sum_ptr_(reinterpret_cast(params.softmax_stats_ptr)), + softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes) { + softmax_max_ptr_ = reinterpret_cast(params.softmax_stats_ptr); + + int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + // MMA row0 index (8x4 thread layout) + + int m_per_mma = 32 / Mma_tile::THREADS_PER_MMA_N * 2; + row0_ = (warp % WARPS_M) * m_per_mma + (lane / 4); + // Decide whether to store the lse values + store_softmax_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0); + + // assume fixed seq length for the batch + size_t const bh_offset = (binfo.sum_s * params.h + binfo.bidh) * sizeof(float) * 2; + softmax_max_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes; + softmax_sum_ptr_ += bh_offset + row0_ * params.softmax_stats_stride_in_bytes + sizeof(float); + }; + + inline __device__ void store(int q_loop, float* p_sum, float* p_max) { + if (store_softmax_) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + float sum0 = p_sum[mi * 2]; + float sum1 = p_sum[mi * 2 + 1]; + float max0 = p_max[mi * 2]; + float max1 = p_max[mi * 2 + 1]; + + int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA; + if (row0_ + row_offset < actual_q_len_) { + fmha::stg(softmax_max_ptr_ + row_offset * softmax_stats_stride_in_bytes_, max0); + fmha::stg(softmax_sum_ptr_ + row_offset * softmax_stats_stride_in_bytes_, sum0); + } + if (row0_ + row_offset + 8 < actual_q_len_) { + fmha::stg(softmax_max_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, max1); + fmha::stg(softmax_sum_ptr_ + (row_offset + 8) * softmax_stats_stride_in_bytes_, sum1); + } + } + } + } + + // ptr (total_token_q, h, 2) float + char* softmax_sum_ptr_ = nullptr; + char* softmax_max_ptr_ = nullptr; + + // the first row's idx + int row0_; + // actual seq length + int const actual_q_len_ = 0; + int const softmax_stats_stride_in_bytes_ = 0; + + // store lse or not + bool store_softmax_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Flash Attention: default applied to Turing, Ampere fp16 traits + +template +struct Fragment_updater { + // The fragment accumulator. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 4 }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Ctor. 
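  // (the constructor derives this thread's first accumulator row from its warp/lane position
  //  and lets only one lane per row -- lane % 4 == 0 in warps with warp / WARPS_M == 0 --
  //  write the per-row log-sum-exp statistics into the [bidb, bidh] slice of lse_ptr)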
+ template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : actual_seqlen_(binfo.actual_seqlen), + softmax_lse_ptr_(reinterpret_cast(params.lse_ptr)) // [b, h, s] + { + int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + // MMA row0 index (8x4 thread layout) + row0_ = (warp % WARPS_M) * Mma_tile::M_PER_MMA + (lane / 4); + // Decide whether to store the lse values + store_lse_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0); + + // assume fixed seq length for the batch + size_t const bh_offset = + (binfo.bidb * params.h + binfo.bidh) * binfo.actual_seqlen * BYTES_PER_ELEMENT; + softmax_lse_ptr_ += bh_offset + row0_ * BYTES_PER_ELEMENT; + }; + + // init all statistics + inline __device__ Fragment_updater() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + curr_max_[row_i] = -HUGE_VALF; + prev_max_[row_i] = -HUGE_VALF; + prev_sum_[row_i] = 0.0f; + curr_sum_[row_i] = 0.0f; + } + } + + // Update o. + inline __device__ void update_o(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + Fragment_accu const (&local_acc_o)[MMAS_M][MMAS_N]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + uint32_t alpha[2], beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + float a = prev_sum_[jj] * __expf(prev_max_[jj] - curr_max_[jj]); + // The diviser. + float b = + (curr_sum_[jj] == 0.f || curr_sum_[jj] != curr_sum_[jj]) ? 1.f : 1.f / curr_sum_[jj]; + // Convert back to FP16x2. + alpha[ii] = fmha::float2_to_half2(a, a); + beta[ii] = fmha::float2_to_half2(b, b); + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators in FP16x2. + uint32_t local_o_pair = local_acc_o[mi][ni].reg(ii); + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hfma2(alpha[ii & 1], acc_o_pair, local_o_pair); + acc_o_pair = fmha::hmul2(acc_o_pair, beta[ii & 1]); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2], beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + alpha[ii] = prev_sum_[jj] * __expf(prev_max_[jj] - curr_max_[jj]); + // The diviser. + beta[ii] = + (curr_sum_[jj] == 0.f || curr_sum_[jj] != curr_sum_[jj]) ? 1.f : 1.f / curr_sum_[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The registers. + float2 local_o_pair = fmha::half2_to_float2(local_acc_o[mi][ni].reg(ii)); + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Do the math in Fp32. + acc_o_pair.x = (alpha[ii & 1] * acc_o_pair.x + local_o_pair.x) * beta[ii & 1]; + acc_o_pair.y = (alpha[ii & 1] * acc_o_pair.y + local_o_pair.y) * beta[ii & 1]; + + // Convert back to Fp16x2. 
+ acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Update max scale + inline __device__ void update_acc_max() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + float pre_curr_max_ = curr_max_[row_i]; + curr_max_[row_i] = fmaxf(prev_max_[row_i], curr_max_[row_i]); + prev_max_[row_i] = pre_curr_max_; + } + } + + // Update max scale + inline __device__ void update_acc_sum() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + float pre_curr_sum_ = curr_sum_[row_i]; + curr_sum_[row_i] = + __expf(prev_max_[row_i] - curr_max_[row_i]) * curr_sum_[row_i] + prev_sum_[row_i]; + prev_sum_[row_i] = pre_curr_sum_; + } + } + + inline __device__ void store(int q_loop) { + if (store_lse_) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + float row0_lse = curr_max_[mi * 2] + __logf(curr_sum_[mi * 2]); + float row1_lse = curr_max_[mi * 2 + 1] + __logf(curr_sum_[mi * 2 + 1]); + int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA; + if (row0_ + row_offset < actual_seqlen_) { + fmha::stg(softmax_lse_ptr_ + row_offset * BYTES_PER_ELEMENT, row0_lse); + } + if (row0_ + row_offset + 8 < actual_seqlen_) { + fmha::stg(softmax_lse_ptr_ + (row_offset + 8) * BYTES_PER_ELEMENT, row1_lse); + } + } + } + } + + // Update scales. + float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float curr_sum_[ROWS_PER_THREAD] = {0}; + float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + ; + float prev_sum_[ROWS_PER_THREAD] = {0}; + + // ptr + char* softmax_lse_ptr_ = nullptr; + + // the first row's idx + int row0_ = 0; + // actual seq length + int const actual_seqlen_ = 0; + + // store lse or not + bool store_lse_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Flash attention to update the accumulators in the 2nd GEMM when we accumulate in FP32. +// Support both hmma_fp32 and ampere_hmma_bf16 +template +struct Fragment_updater_ampere_fp32 { + // The fragment accumulator. + using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + // The number of MMAs in the N dimension. + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = 2 * MMAS_M }; + + // The number of registers per thread. + enum { REGS_PER_THREAD = 8 }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // softmax data bytes + enum { BYTES_PER_ELEMENT = sizeof(float) }; + + // Ctor. 
+ template + inline __device__ Fragment_updater_ampere_fp32(Params const& params, Block_info const& binfo) + : actual_seqlen_(binfo.actual_seqlen), + softmax_lse_ptr_(reinterpret_cast(params.lse_ptr)) // [b, h, s] + { + int warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + // MMA row0 index (8x4 thread layout) + row0_ = (warp % WARPS_M) * Mma_tile::M_PER_MMA + (lane / 4); + // Decide whether to store the lse values + store_lse_ = (lane % 4 == 0 && int(warp / WARPS_M) == 0); + + // assume fixed seq length for the batch + size_t const bh_offset = + (binfo.bidb * params.h + binfo.bidh) * binfo.actual_seqlen * BYTES_PER_ELEMENT; + softmax_lse_ptr_ += bh_offset + row0_ * BYTES_PER_ELEMENT; + }; + + // init all statistics + inline __device__ Fragment_updater_ampere_fp32() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + curr_max_[row_i] = -HUGE_VALF; + prev_max_[row_i] = -HUGE_VALF; + prev_sum_[row_i] = 0.0f; + curr_sum_[row_i] = 0.0f; + } + } + + // Update o after P * V + inline __device__ void update_o(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + Fragment_accu const (&local_acc_o)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2], beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + alpha[ii] = prev_sum_[jj] * __expf(prev_max_[jj] - curr_max_[jj]); + // The diviser. + beta[ii] = + (curr_sum_[jj] == 0.f || curr_sum_[jj] != curr_sum_[jj]) ? 1.f : 1.f / curr_sum_[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The register from P. + float local_acc_o_f = local_acc_o[mi][ni].elt(ii); + // The register for O. + float acc_o_f = acc_o[mi][ni].elt(ii); + // Compute the next accumulator. + acc_o_f = (alpha[(ii & 2) / 2] * acc_o_f + local_acc_o_f) * beta[(ii & 2) / 2]; + // Update the accumulator. + acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + // Update o before P * V + inline __device__ void update_o(Fragment_accu (&acc_o)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors for the 2 rows. + float alpha[2], beta[2]; +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The row. + int jj = 2 * mi + ii; + // The multiplier. + alpha[ii] = prev_sum_[jj] * __expf(prev_max_[jj] - curr_max_[jj]); + // The diviser. + beta[ii] = + (curr_sum_[jj] == 0.f || curr_sum_[jj] != curr_sum_[jj]) ? 1.f : 1.f / curr_sum_[jj]; + } + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The register for O. + float acc_o_f = acc_o[mi][ni].elt(ii); + // Compute the next accumulator. + acc_o_f = alpha[(ii & 2) / 2] * acc_o_f * beta[(ii & 2) / 2]; + // Update the accumulator. 
+ acc_o[mi][ni].elt(ii) = acc_o_f; + } + } + } + } + + // Update max scale + inline __device__ void update_acc_max() { +#pragma unroll + for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) { + float curr_max = curr_max_[ii]; + curr_max_[ii] = fmaxf(prev_max_[ii], curr_max); + prev_max_[ii] = curr_max; + } + } + + // Update max scale + inline __device__ void update_acc_sum() { +#pragma unroll + for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) { + float curr_sum = curr_sum_[ii]; + curr_sum_[ii] = __expf(prev_max_[ii] - curr_max_[ii]) * curr_sum_[ii] + prev_sum_[ii]; + prev_sum_[ii] = curr_sum; + } + } + + inline __device__ void store(int q_loop) { + if (store_lse_) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + float row0_lse = curr_max_[mi * 2] + __logf(curr_sum_[mi * 2]); + float row1_lse = curr_max_[mi * 2 + 1] + __logf(curr_sum_[mi * 2 + 1]); + int row_offset = q_loop * Cta_tile::M + mi * Mma_tile::M_PER_MMA_PER_CTA; + if (row0_ + row_offset < actual_seqlen_) { + fmha::stg(softmax_lse_ptr_ + row_offset * BYTES_PER_ELEMENT, row0_lse); + } + if (row0_ + row_offset + 8 < actual_seqlen_) { + fmha::stg(softmax_lse_ptr_ + (row_offset + 8) * BYTES_PER_ELEMENT, row1_lse); + } + } + } + } + + // Update scales. + float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float curr_sum_[ROWS_PER_THREAD] = {0}; + float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float prev_sum_[ROWS_PER_THREAD] = {0}; + + // ptr + char* softmax_lse_ptr_ = nullptr; + + // the first row's idx + int row0_ = 0; + // actual seq length + int const actual_seqlen_ = 0; + + // store lse or not + bool store_lse_ = false; +}; + +template +struct Fragment_updater + : public Fragment_updater_ampere_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Fragment_updater_ampere_fp32; + + // Ctor. + template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Default ctor + Fragment_updater() = default; +}; + +template +struct Fragment_updater + : public Fragment_updater_ampere_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Fragment_updater_ampere_fp32; + + // Ctor. + template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Default ctor + Fragment_updater() = default; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_updater + : public Fragment_updater_ampere_fp32 { + // The traits. + using Traits = fmha::Turing_hmma_fp32_traits; + // The base class. + using Base = Fragment_updater_ampere_fp32; + + // Ctor. + template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Default ctor + Fragment_updater() = default; +}; + +template +struct Fragment_updater + : public Fragment_updater_ampere_fp32 { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Fragment_updater_ampere_fp32; + + // Ctor. + template + inline __device__ Fragment_updater(Params const& params, Block_info const& binfo) + : Base(params, binfo) {} + + // Default ctor + Fragment_updater() = default; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_updater { + // The traits. + using Traits = Volta_hmma_fp16_16x16x16_traits; + + // The fragments. 
+ using Fragment_accu = Fragment_accumulator; + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::VALID_MMAS_N }; + + // The number of rows per thread. + enum { ROWS_PER_THREAD = MMAS_M }; + + // The number of registers per thread + enum { REGS_PER_THREAD = 8 }; + + // init all statistics + inline __device__ Fragment_updater() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + curr_max_[row_i] = -HUGE_VALF; + prev_max_[row_i] = -HUGE_VALF; + prev_sum_[row_i] = 0.0f; + curr_sum_[row_i] = 0.0f; + } + } + + // Update o. + inline __device__ void update_o(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], + Fragment_accu const (&local_acc_o)[MMAS_M][MMAS_N]) { +#ifdef HALF_ACCUMULATION_FOR_FLASH_ATTENTION // Half accumulation + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // Precompute the scaling factors. + uint32_t alpha, beta; + // The multiplier. + float a = prev_sum_[mi] * __expf(prev_max_[mi] - curr_max_[mi]); + // The diviser. + float b = + (curr_sum_[mi] == 0.f || curr_sum_[mi] != curr_sum_[mi]) ? 1.f : 1.f / curr_sum_[mi]; + // Convert back to FP16. + alpha = fmha::float2_to_half2(a, a); + beta = fmha::float2_to_half2(b, b); + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators packed in FP16x2. + uint32_t local_o_pair = local_acc_o[mi][ni].reg(ii); + uint32_t acc_o_pair = acc_o[mi][ni].reg(ii); + + // Apply the scaling. + acc_o_pair = fmha::hmul2(fmha::hfma2(alpha, acc_o_pair, local_o_pair), beta); + + // Update the register. + acc_o[mi][ni].reg(ii) = acc_o_pair; + } + } + } +#else // Float accumulation +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The multiplier. + float alpha = prev_sum_[mi] * __expf(prev_max_[mi] - curr_max_[mi]); + // The diviser. + float beta = + (curr_sum_[mi] == 0.f || curr_sum_[mi] != curr_sum_[mi]) ? 1.f : 1.f / curr_sum_[mi]; + +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < REGS_PER_THREAD; ++ii) { + // The accumulators. Convert from FP16x2 to FP32x2. + float2 local_o_pair = fmha::half2_to_float2(local_acc_o[mi][ni].reg(ii)); + float2 acc_o_pair = fmha::half2_to_float2(acc_o[mi][ni].reg(ii)); + + // Apply the scaling. + acc_o_pair.x = (alpha * acc_o_pair.x + local_o_pair.x) * beta; + acc_o_pair.y = (alpha * acc_o_pair.y + local_o_pair.y) * beta; + + // Update the register after converting back to FP16x2. 
+ acc_o[mi][ni].reg(ii) = fmha::float2_to_half2(acc_o_pair); + } + } + } +#endif // defined HALF_ACCUMULATION_FOR_FLASH_ATTENTION + } + + // Update max scale + inline __device__ void update_acc_max() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + float pre_curr_max_ = curr_max_[row_i]; + curr_max_[row_i] = fmaxf(prev_max_[row_i], curr_max_[row_i]); + prev_max_[row_i] = pre_curr_max_; + } + } + + // Update max scale + inline __device__ void update_acc_sum() { +#pragma unroll + for (int row_i = 0; row_i < ROWS_PER_THREAD; ++row_i) { + float pre_curr_sum_ = curr_sum_[row_i]; + curr_sum_[row_i] = + __expf(prev_max_[row_i] - curr_max_[row_i]) * curr_sum_[row_i] + prev_sum_[row_i]; + prev_sum_[row_i] = pre_curr_sum_; + } + } + + // updater scales + float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float curr_sum_[ROWS_PER_THREAD] = {0}; + float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; + float prev_sum_[ROWS_PER_THREAD] = {0}; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_from_size_in_bytes { + using Type = Fragment(sizeof(Data_type_))>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_from_size_in_bytes { + using Type = Fragment; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void clear(Fragment (&frag)[M][N]) { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + frag[mi][ni].clear(); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { + fmha::clear(acc); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { + fmha::clear(acc); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { + fmha::clear(acc); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Clear_accumulator { + template + static inline __device__ void apply(Acc (&acc)[M][N], bool enable_i2f_trick = true) { +#if defined(USE_I2F_EMULATION_TRICK) + if (enable_i2f_trick) { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { +#pragma unroll + for (int ii = 0; ii < Acc::NUM_REGS; ++ii) { + acc[mi][ni].reg(ii) = uint32_t(FP32_I2F_MAGIC_NUMBER_HEX) / WARPS_K; + } + } + } + } else +#endif // defined(USE_I2F_EMULATION_TRICK) + { + fmha::clear(acc); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gemm.h b/csrc/fmha_v2/fmha/gemm.h new file mode 100644 index 0000000000..e1422e4f6e --- /dev/null +++ b/csrc/fmha_v2/fmha/gemm.h @@ -0,0 +1,35 @@ +/* + * 
SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Acc (&acc)[M][N], A const (&a)[M], B const (&b)[N]) { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + acc[mi][ni].mma(a[mi], b[ni]); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_o.h b/csrc/fmha_v2/fmha/gmem_tile_o.h new file mode 100644 index 0000000000..c3177dc219 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_o.h @@ -0,0 +1,465 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include + +namespace fmha { +namespace v1 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_gmem_tile_o { + // The mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = 16 }; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + // END OF DEBUG. + + // Make sure the math is correct. 
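To make the derived constants above concrete, here is the same arithmetic evaluated for one assumed configuration (a 128x64 CTA tile, 128 threads, FP16 output, 16-byte stores, and a 64-row cap for `ROWS_PER_LOOP`); the numbers are purely illustrative, not values read out of a particular template instantiation:

```cpp
// Worked example of the store-tiling arithmetic above for an assumed config.
#include <cstdio>

constexpr int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int CTA_M = 128, CTA_N = 64, THREADS_PER_CTA = 128;
  constexpr int BYTES_PER_ELEMENT = 2;   // fp16 output
  constexpr int BYTES_PER_STG = 16;      // one 128-bit store per thread

  constexpr int BYTES_PER_ROW = CTA_N * BYTES_PER_ELEMENT;            // 128 B per output row
  constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG;      // 8 threads cover a row
  constexpr int ROWS_PER_STG = THREADS_PER_CTA / THREADS_PER_ROW;     // 16 rows per CTA-wide STG
  constexpr int ROWS_PER_LOOP = 64;                                   // assumed cap, see header
  constexpr int LOOPS = div_up(CTA_M, ROWS_PER_LOOP);                 // 2 outer iterations
  constexpr int STGS_PER_LOOP = div_up(ROWS_PER_LOOP, ROWS_PER_STG);  // 4 STGs per iteration
  constexpr int STGS = STGS_PER_LOOP * LOOPS;                         // 8 STGs in total
  constexpr bool HAS_INCOMPLETE_STG = (CTA_M % ROWS_PER_STG) != 0;    // false for this shape

  std::printf("threads/row=%d rows/stg=%d loops=%d stgs/loop=%d stgs=%d incomplete=%d\n",
              THREADS_PER_ROW, ROWS_PER_STG, LOOPS, STGS_PER_LOOP, STGS,
              (int)HAS_INCOMPLETE_STG);
}
```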
+ static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + + // Ctor. + template + inline __device__ Hmma_gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + o_ptr_(reinterpret_cast(params.o_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is that thread active on the last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Account for the CTA-wide row offset (no loop mode). + row += cta_row_offset; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row * params.o_stride_in_bytes; + // Take the batch/head offset into account. + row_offset += (int64_t)binfo.bidx * BYTES_PER_ROW; + // Assemble the final pointer. + o_ptr_ += row_offset + col * BYTES_PER_STG; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + if (blockIdx.x == 0) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + dst[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } else { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_)) { + fmha::ldg(dst[ii], o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_); + } + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_)) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, src[ii]); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + uint4 tmp[STGS_PER_LOOP]; +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + tmp[ii].x = fmha::hadd2(src[ii].x, old[ii].x); + tmp[ii].y = fmha::hadd2(src[ii].y, old[ii].y); + tmp[ii].z = fmha::hadd2(src[ii].z, old[ii].z); + tmp[ii].w = fmha::hadd2(src[ii].w, old[ii].w); + } + this->store(tmp, mi); + } + + // Move the pointer to the next location. + inline __device__ void move() { o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_o_stride_in_bytes_; + // The pointer. + char* o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Volta_hmma_fp16_16x16x16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset = 0) + : Base(params, binfo, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Turing_hmma_fp16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset = 0) + : Base(params, binfo, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_fp16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset = 0) + : Base(params, binfo, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// I M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Imma_gmem_tile_o { + // The mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = 1 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = 4 }; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + // END OF DEBUG. + + // Make sure the math is correct. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads (last STG). + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + + // Are all threads active? + enum { ALL_THREADS_ACTIVE = ROWS_PER_STG <= ROWS_PER_LOOP }; + + // The number of active threads. + enum { ACTIVE_THREADS_ = Cta_tile::THREADS_PER_CTA * ROWS_PER_LOOP / ROWS_PER_STG }; + + // The number of active threads. + enum { ACTIVE_THREADS = ALL_THREADS_ACTIVE ? Cta_tile::THREADS_PER_CTA : ACTIVE_THREADS_ }; + + // Ctor. 
+ template + inline __device__ Imma_gmem_tile_o(Params const& params, int bidx, int tidx, int cta_row_offset) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + params_scale_bmm2_(params.scale_bmm2), + params_enable_i2f_trick_(params.enable_i2f_trick), + o_ptr_(reinterpret_cast(params.o_ptr)) +#if USE_DEMO_BERT_PARAMS + , + o_scratch_ptr_(nullptr) { +#else + , + o_scratch_ptr_(reinterpret_cast(params.o_scratch_ptr)) { +#endif + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is it an active thread? + is_active_ = ALL_THREADS_ACTIVE || row < ROWS_PER_LOOP; + + // Is that thread active on the last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Update the row. + row += cta_row_offset; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row * params.o_stride_in_bytes; + // Take the batch/head offset into account. + row_offset += (int64_t)bidx * BYTES_PER_ROW; + // Assemble the final pointers. + o_ptr_ += row_offset + col * BYTES_PER_STG; + + // For the scratch space, the pointer has int32 type so it accounts for the *4 factor. + o_scratch_ptr_ += blockIdx.y * STGS_PER_LOOP * ACTIVE_THREADS + tidx; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + if (blockIdx.x == 0) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + dst[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } else if (ALL_THREADS_ACTIVE || is_active_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + fmha::ldg(dst[ii], o_scratch_ptr_ + ii * ACTIVE_THREADS); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { + // The scale. + float const& scale = reinterpret_cast(params_scale_bmm2_); +// Iterate over the different STGs. +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + // The accumulators are in int32_t. + int4 const& val = reinterpret_cast(src[ii]); + + // Extract the floats and scale. + float f0, f1, f2, f3; +#if defined(USE_I2F_EMULATION_TRICK) + if (params_enable_i2f_trick_) { + f0 = reinterpret_cast(val.x) - FP32_I2F_MAGIC_NUMBER; + f1 = reinterpret_cast(val.y) - FP32_I2F_MAGIC_NUMBER; + f2 = reinterpret_cast(val.z) - FP32_I2F_MAGIC_NUMBER; + f3 = reinterpret_cast(val.w) - FP32_I2F_MAGIC_NUMBER; + } else +#endif // defined(USE_I2F_EMULATION_TRICK) + { + f0 = static_cast(val.x); + f1 = static_cast(val.y); + f2 = static_cast(val.z); + f3 = static_cast(val.w); + } + + // Apply the scaling. + f0 *= scale; + f1 *= scale; + f2 *= scale; + f3 *= scale; + + // Convert the 4 floats to char4. + uint32_t dst = float4_to_char4(f0, f1, f2, f3); + + // Store the result. + int jj = mi * STGS_PER_LOOP + ii; + if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_)) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, dst); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + // Do the reduction. 
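The `FP32_I2F_MAGIC_NUMBER` branch above relies on the classic integer-to-float bit trick: when the accumulator is pre-biased with a large magic constant, the integer result lands in the mantissa of a float, so subtracting the magic value as a float recovers the integer without an I2F instruction. A host-side illustration of the underlying trick; the constant used here is the textbook `1.5 * 2^23`, which is an assumption, since `FP32_I2F_MAGIC_NUMBER_HEX` is defined elsewhere in the tree:

```cpp
// Host-side illustration of the int-to-float "magic number" trick behind
// USE_I2F_EMULATION_TRICK.  The magic constant is the classic choice and is
// assumed, not copied from the headers.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const uint32_t MAGIC_HEX = 0x4B400000u;  // bit pattern of 1.5f * 2^23
  const float MAGIC_F = 12582912.0f;       // the same value as a float

  for (int32_t n : {0, 1, 42, 100000, -7, -65536}) {
    // Pretend the accumulator was initialized to MAGIC_HEX and the integer
    // result was then added into it (what the biased accumulation yields).
    uint32_t bits = MAGIC_HEX + static_cast<uint32_t>(n);
    float as_float;
    std::memcpy(&as_float, &bits, sizeof(as_float));
    // Subtracting the magic number as a float recovers the integer exactly
    // (valid for |n| < 2^22 with this constant).
    std::printf("n=%8d  recovered=%10.1f\n", n, as_float - MAGIC_F);
  }
}
```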
+ uint4 tmp[STGS_PER_LOOP]; +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int4 const& src_ii = reinterpret_cast(src[ii]); + int4 const& old_ii = reinterpret_cast(old[ii]); + + int32_t x = src_ii.x + old_ii.x; + int32_t y = src_ii.y + old_ii.y; + int32_t z = src_ii.z + old_ii.z; + int32_t w = src_ii.w + old_ii.w; + + tmp[ii].x = reinterpret_cast(x); + tmp[ii].y = reinterpret_cast(y); + tmp[ii].z = reinterpret_cast(z); + tmp[ii].w = reinterpret_cast(w); + } + + // The last CTA stores INT8 values to the final location. + if (blockIdx.x == CTAS_PER_HEAD - 1) { + this->store(tmp, mi); + + // Other CTAs store INT32 values to the scratch space. + } else if (ALL_THREADS_ACTIVE || is_active_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + fmha::stg(o_scratch_ptr_ + ii * ACTIVE_THREADS, tmp[ii]); + } + } + } + + // Move the pointer. + inline __device__ void move() { o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_o_stride_in_bytes_; + // The scaling factor to convert to int8. + uint32_t const params_scale_bmm2_; + // Do we enable the i2f trick? + bool const params_enable_i2f_trick_; + // The pointer. + char* o_ptr_; + // The scratch pointer for 32-bit reductions. + int32_t* o_scratch_ptr_; + + // Is it an active thread? When ROWS_PER_STG > ROWS_PER_LOOP, some threads do not store. + int is_active_, is_active_for_last_stg_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Imma_gmem_tile_o { + // The traits class. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The base class. + using Base = Imma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info.bidx, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Imma_gmem_tile_o { + // The traits class. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Imma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info.bidx, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace v1 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_o_packed.h b/csrc/fmha_v2/fmha/gmem_tile_o_packed.h new file mode 100644 index 0000000000..dc13b37f19 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_o_packed.h @@ -0,0 +1,1349 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include +#include +#include + +namespace fmha { +namespace v2 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_gmem_tile_o { + // The mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = BYTES_PER_ELEMENT_ }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The valid size of a row in bytes. + // Note: cross-attention kernels rely on head dim from runtime instead of from compile-time. + // This approach deviates from self-attention kernels. To explore a unified approach. + // enum { VALID_BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = BYTES_PER_STG_ }; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + // END OF DEBUG. + + // Make sure the math is correct. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + + // Ctor. + template + inline __device__ Hmma_gmem_tile_o(Params const& params, Block_info const& binfo, int tidx, + int cta_row_offset, int cta_col_offset_in_bytes = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(binfo.actual_q_seqlen), + o_ptr_(reinterpret_cast(params.o_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is that thread active on the last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Store the row/col to update the predicates in load. + row_ = cta_row_offset + row; + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_STG; + init_row_ = row_; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row_ * params.o_stride_in_bytes; + // The amount of bytes per row without padding. + int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT; + // Take the batch/head offset into account. TODO: Fix me! 
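The constructor above builds one byte offset per thread from the sequence row, the flattened batch/head index, and the thread's column, and then disables threads whose column falls into the padding between the runtime head dimension `params.dv` and the compile-time tile width. A host-side sketch of that address computation and predicate; all sizes and strides below are invented for the example:

```cpp
// Host-side sketch of the per-thread output addressing done in the ctor above.
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed configuration.
  const int BYTES_PER_ELEMENT = 2;  // fp16 output
  const int BYTES_PER_STG = 16;     // 128-bit stores
  const int CTA_N = 64;             // padded head dimension of the tile
  const int THREADS_PER_ROW = CTA_N * BYTES_PER_ELEMENT / BYTES_PER_STG;  // 8

  const int dv = 40;                       // runtime (valid) head dimension
  const int64_t o_stride_in_bytes = 1024;  // bytes between consecutive rows of O
  const int bidx = 3;                      // flattened batch*head index
  const int cta_row_offset = 0;

  const int valid_bytes_per_row = dv * BYTES_PER_ELEMENT;  // 80 bytes actually written

  for (int tidx : {0, 4, 5, 13}) {
    int row = cta_row_offset + tidx / THREADS_PER_ROW;
    int col_in_bytes = (tidx % THREADS_PER_ROW) * BYTES_PER_STG;

    int64_t offset = (int64_t)row * o_stride_in_bytes      // sequence position
                   + (int64_t)bidx * valid_bytes_per_row   // batch/head base
                   + col_in_bytes;                         // column within the row
    bool active = col_in_bytes < valid_bytes_per_row;      // skip the padded tail

    std::printf("tidx=%2d row=%d col_bytes=%3d offset=%5lld active=%d\n",
                tidx, row, col_in_bytes, (long long)offset, (int)active);
  }
}
```

With these numbers, threads whose 16-byte store would start at byte 80 or beyond of a row are masked off, which is exactly what the `active_` predicate expresses.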
+ // + // row_offset += binfo.bidx * VALID_BYTES_PER_ROW; + // + row_offset += binfo.bidx * valid_bytes_per_row; + + // Assemble the final pointer. + o_ptr_ += row_offset + col_in_bytes_; + init_o_ptr_ = o_ptr_; + + // Do not store if the thread is in the padded area + active_ = col_in_bytes_ < valid_bytes_per_row; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + if (blockIdx.x == 0) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + dst[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } else { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= actual_seqlen_) { + break; + } + if (active_ && (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_))) { + fmha::ldg(dst[ii], o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_); + } + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= actual_seqlen_) { + break; + } + if (active_ && (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_))) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, src[ii]); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + uint4 tmp[STGS_PER_LOOP]; +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + tmp[ii].x = fmha::hadd2(src[ii].x, old[ii].x); + tmp[ii].y = fmha::hadd2(src[ii].y, old[ii].y); + tmp[ii].z = fmha::hadd2(src[ii].z, old[ii].z); + tmp[ii].w = fmha::hadd2(src[ii].w, old[ii].w); + } + this->store(tmp, mi); + } + + // Move the pointer to the next location. + inline __device__ void move(int const steps = 1) { + row_ += ROWS * steps; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps; + } + + inline __device__ void move_to(int const step) { + row_ = init_row_ + ROWS * step; + o_ptr_ = init_o_ptr_ + (int64_t)ROWS * params_o_stride_in_bytes_ * step; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The pointer. + char* o_ptr_; + char* init_o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + + // The row loaded by this thread. + int row_, col_in_bytes_; + int init_row_; + // The length of the sequence loaded by that CTA. + int actual_seqlen_; + // Is that thread active when it comes to loading data? + int active_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Volta_hmma_fp16_16x16x16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Turing_hmma_fp16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_fp16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_bf16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // The epilogue data type + using Epilogue_type = typename Traits::Epilogue_type; + + // DEBUG. + static_assert((Base::THREADS_PER_ROW == 16 || Base::THREADS_PER_ROW == 32 || + Base::THREADS_PER_ROW == 64 || Base::THREADS_PER_ROW == 128) && + Base::BYTES_PER_STG == 8, + ""); + + // END OF DEBUG. + + enum { STGS_PER_LOOP = Base::STGS_PER_LOOP }; + + enum { ROWS_PER_STG = Base::ROWS_PER_STG }; + + enum { STGS = Base::STGS }; + + enum { HAS_INCOMPLETE_STG = Base::HAS_INCOMPLETE_STG }; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} + + // Load data from global memory. + inline __device__ void load(uint4 const (&dst)[STGS_PER_LOOP], int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_) { + break; + } + + float x = reinterpret_cast(src[ii].x); + float y = reinterpret_cast(src[ii].y); + float z = reinterpret_cast(src[ii].z); + float w = reinterpret_cast(src[ii].w); + + uint2 out = float4_to_16bit_x4(x, y, z, w); + if (this->active_ && + (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_))) { + fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out); + } + } + } + + // Store data to global memory. 
+ inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Hmma_gmem_tile_o { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Hmma_gmem_tile_o; + + // The epilogue data type + using Epilogue_type = typename Traits::Epilogue_type; + + // DEBUG. + static_assert((Base::THREADS_PER_ROW == 16 || Base::THREADS_PER_ROW == 32 || + Base::THREADS_PER_ROW == 64 || Base::THREADS_PER_ROW == 128) && + Base::BYTES_PER_STG == 8, + ""); + + // END OF DEBUG. + + enum { STGS_PER_LOOP = Base::STGS_PER_LOOP }; + + enum { ROWS_PER_STG = Base::ROWS_PER_STG }; + + enum { STGS = Base::STGS }; + + enum { HAS_INCOMPLETE_STG = Base::HAS_INCOMPLETE_STG }; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} + + // Load data from global memory. + inline __device__ void load(uint4 const (&dst)[STGS_PER_LOOP], int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int jj = mi * STGS_PER_LOOP + ii; + if (this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_) { + break; + } + + float x = reinterpret_cast(src[ii].x); + float y = reinterpret_cast(src[ii].y); + float z = reinterpret_cast(src[ii].z); + float w = reinterpret_cast(src[ii].w); + + uint2 out = float4_to_16bit_x4(x, y, z, w); + if (this->active_ && + (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_))) { + fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// I M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t quantize(int4 const val, float const scale, + bool const params_enable_i2f_trick) { + // Extract the floats and scale. + float f0, f1, f2, f3; +#if defined(USE_I2F_EMULATION_TRICK) + if (params_enable_i2f_trick) { + f0 = reinterpret_cast(val.x) - FP32_I2F_MAGIC_NUMBER; + f1 = reinterpret_cast(val.y) - FP32_I2F_MAGIC_NUMBER; + f2 = reinterpret_cast(val.z) - FP32_I2F_MAGIC_NUMBER; + f3 = reinterpret_cast(val.w) - FP32_I2F_MAGIC_NUMBER; + } else +#endif // defined(USE_I2F_EMULATION_TRICK) + { + f0 = static_cast(val.x); + f1 = static_cast(val.y); + f2 = static_cast(val.z); + f3 = static_cast(val.w); + } + + // Apply the scaling. + f0 *= scale; + f1 *= scale; + f2 *= scale; + f3 *= scale; + + // Convert the 4 floats to char4. 
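`quantize` above scales the four int32 accumulator lanes by the bmm2 scale and packs them into four saturated int8 values. A plain-C++ stand-in for that final packing step, assuming round-to-nearest with saturation; the device-side `float4_to_char4` helper is not reproduced here:

```cpp
// Host-side stand-in for the scale-and-pack step of quantize() above.
// Round-to-nearest with saturation to [-128, 127] is assumed.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint32_t float4_to_char4_ref(float f0, float f1, float f2, float f3) {
  auto sat_s8 = [](float f) -> uint32_t {
    int v = (int)std::lrintf(std::min(127.f, std::max(-128.f, f)));
    return (uint32_t)(uint8_t)(int8_t)v;
  };
  return sat_s8(f0) | (sat_s8(f1) << 8) | (sat_s8(f2) << 16) | (sat_s8(f3) << 24);
}

int main() {
  // Example int32 accumulators and a bmm2 scale, as fed into the packer.
  int32_t acc[4] = {1200, -5000, 37, 900000};
  float scale = 0.01f;

  float f[4];
  for (int i = 0; i < 4; ++i) f[i] = (float)acc[i] * scale;

  uint32_t packed = float4_to_char4_ref(f[0], f[1], f[2], f[3]);
  std::printf("packed = 0x%08x  (lanes: %d %d %d %d)\n", packed,
              (int8_t)(packed & 0xff), (int8_t)((packed >> 8) & 0xff),
              (int8_t)((packed >> 16) & 0xff), (int8_t)((packed >> 24) & 0xff));
}
```

The last lane (9000 after scaling) saturates to 127, showing why the scale has to be chosen so that the bulk of the accumulator range fits into int8.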
+ uint32_t dst = float4_to_char4(f0, f1, f2, f3); + + return dst; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helpers to pack 4 registers representing a Src_type into a destination register with 4 8bit +// values representing Dst_type. Scale factor is assumed to be always FP32 for 32-bit accumulators. +template +struct Acc_packer {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Signed INT32 => INT8. +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + // The accumulators are in int32_t. + int4 const& val = reinterpret_cast(src_regs); + + // Quantize... + uint32_t dst = quantize(val, scale, this_->params_enable_i2f_trick_); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + // The accumulators are in int32_t. + int4 const& val = reinterpret_cast(src_regs); + + // Quantize... + uint32_t dst = quantize(val, 1.0f, this_->params_enable_i2f_trick_); + return dst; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FP32 => FP8. +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float const scale = reinterpret_cast(this_->params_scale_bmm2_); + + float4 const& val = reinterpret_cast(src_regs); + + uint32_t dst = + fmha::float4_to_e4m3x4(val.x * scale, val.y * scale, val.z * scale, val.w * scale); + return dst; + } + + template + static inline __device__ uint16_t run(This const* this_, uint2 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + + float2 const& val = reinterpret_cast(src_regs); + + uint16_t dst = fmha::float2_to_e4m3x2(val.x * scale, val.y * scale); + return dst; + } +}; + +// FP32 => FP8. +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float4 const& val = reinterpret_cast(src_regs); + + uint32_t dst = fmha::float4_to_e4m3x4(val.x, val.y, val.z, val.w); + return dst; + } + + template + static inline __device__ uint16_t run(This const* this_, uint2 const& src_regs) { + float2 const& val = reinterpret_cast(src_regs); + + uint16_t dst = fmha::float2_to_e4m3x2(val.x, val.y); + return dst; + } +}; + +// FP16 => FP8. +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + uint2 dst; + dst.x = fmha::half4_to_e4m3x4(fmha::hmul2(src_regs.x, this_->params_scale_bmm2_), + fmha::hmul2(src_regs.y, this_->params_scale_bmm2_)); + dst.y = fmha::half4_to_e4m3x4(fmha::hmul2(src_regs.z, this_->params_scale_bmm2_), + fmha::hmul2(src_regs.w, this_->params_scale_bmm2_)); + + return dst; + } +}; + +// FP16 => FP8. 
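The FP8 `Acc_packer` specializations above differ from the int8 path mainly in the target format: the accumulators are scaled and converted to e4m3, whose largest finite magnitude is 448, so out-of-range products are saturated by the conversion. A small sketch of the scale-then-saturate behaviour, with the bit-level encoding left to the device-side converters and not modelled here:

```cpp
// Sketch of the value-range handling in the FP32 -> e4m3 packers above.
// Only the scale + saturate step is modelled; the e4m3 bit encoding is not.
#include <algorithm>
#include <cstdio>

int main() {
  const float E4M3_MAX = 448.f;  // largest finite e4m3 magnitude
  const float scale = 0.125f;    // example scale_bmm2 value (assumed)

  float acc[4] = {100.f, -3000.f, 4200.f, 0.4f};
  for (float a : acc) {
    float scaled = a * scale;
    float sat = std::max(-E4M3_MAX, std::min(E4M3_MAX, scaled));
    std::printf("acc=%8.1f  scaled=%8.2f  stored as ~%8.2f%s\n",
                a, scaled, sat, sat != scaled ? " (saturated)" : "");
  }
}
```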
+template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + uint2 dst; + dst.x = fmha::half4_to_e4m3x4(src_regs.x, src_regs.y); + dst.y = fmha::half4_to_e4m3x4(src_regs.z, src_regs.w); + + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + + float4 const& val = reinterpret_cast(src_regs); + + uint32_t dst = + fmha::float4_to_e5m2x4(val.x * scale, val.y * scale, val.z * scale, val.w * scale); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint32_t run(This const* this_, uint4 const& src_regs) { + float4 const& val = reinterpret_cast(src_regs); + + uint32_t dst = fmha::float4_to_e5m2x4(val.x, val.y, val.z, val.w); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + float4 const& val = reinterpret_cast(src_regs); + + uint2 dst = fmha::float4_to_half4(val.x, val.y, val.z, val.w); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + + float4 const& val = reinterpret_cast(src_regs); + + uint2 dst = fmha::float4_to_half4(val.x * scale, val.y * scale, val.z * scale, val.w * scale); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + float4 const& val = reinterpret_cast(src_regs); + + uint2 dst = fmha::float4_to_16bit_x4(val.x, val.y, val.z, val.w); + return dst; + } +}; + +template <> +struct Acc_packer { + template + static inline __device__ uint2 run(This const* this_, uint4 const& src_regs) { + float const& scale = reinterpret_cast(this_->params_scale_bmm2_); + + float4 const& val = reinterpret_cast(src_regs); + + uint2 dst = fmha::float4_to_16bit_x4(val.x * scale, val.y * scale, val.z * scale, + val.w * scale); + return dst; + } +}; + +// support both 32 bit accumulationi and 16 bit accumulation (imma and qmma) +template +struct Gmem_tile_o_8bit { + // static_assert(sizeof(typename Traits::Accumulator_type) == 4); + static_assert(sizeof(typename Traits::C_type) == 1); + + // The mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = 1 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The valid size of a row in bytes. + enum { VALID_BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The size of each STG (16B --> 8bit elements). + enum { BYTES_PER_STG = fmha::Div_up<16, sizeof(typename Traits::Accumulator_type)>::VALUE }; + + // The STG packed data type + using Stg_packed_type = typename Uint_from_size_in_bytes::Type; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? 
ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + + // Make sure the math is correct. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + +#if 0 + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS_PER_LOOP = Mma_tile::M_PER_MMA_PER_CTA }; + // The number of outer loop for the stores. + enum { LOOPS = ROWS / ROWS_PER_LOOP }; + + // Make sure the math is correct. + static_assert(LOOPS == (int)Mma_tile::MMAS_M, ""); + + // The number of "rows" stored per STG -- for it to be the number of rows per MMA instruction. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; +#endif + + // Are all threads active? + enum { ALL_THREADS_ACTIVE = ROWS_PER_STG <= ROWS_PER_LOOP }; + + // The number of active threads. + enum { ACTIVE_THREADS_ = Cta_tile::THREADS_PER_CTA * ROWS_PER_LOOP / ROWS_PER_STG }; + + // The number of active threads. + enum { ACTIVE_THREADS = ALL_THREADS_ACTIVE ? Cta_tile::THREADS_PER_CTA : ACTIVE_THREADS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_o_8bit(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(block_info.actual_q_seqlen), + params_scale_bmm2_(params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2) +#ifdef GENERATE_CUBIN + , + params_enable_i2f_trick_(false) +#else + , + params_enable_i2f_trick_(params.enable_i2f_trick) +#endif + , + o_ptr_(reinterpret_cast(params.o_ptr)) +#if USE_DEMO_BERT_PARAMS + , + o_scratch_ptr_(nullptr) { +#else + , + o_scratch_ptr_(reinterpret_cast(params.o_scratch_ptr)) { +#endif + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is it an active thread for the very last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Store the row to check against the length before loads. + row_ = cta_row_offset + row; + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_STG; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row_ * params.o_stride_in_bytes; + // The amount of bytes per row without padding (runtime). + int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT; + // Take the batch/head offset into account. + row_offset += block_info.bidx * valid_bytes_per_row; + // Assemble the final pointer. + o_ptr_ += row_offset + col_in_bytes_; + + // Is it an active thread? 
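For this 8-bit output tile the STG width follows from the accumulator type: 16 bytes of accumulator registers shrink to `16 / sizeof(Accumulator_type)` output bytes, i.e. 4-byte stores for 32-bit (IMMA / QMMA-FP32) accumulators and 8-byte stores for FP16 accumulators; the predicate set just below then masks off threads whose column lands in the padded part of the row. A small sketch of those two computations, with sizes assumed for the example:

```cpp
// Sketch of the 8-bit epilogue's STG width and padding predicate.
// Accumulator sizes and tile widths are illustrative assumptions.
#include <cstdio>

constexpr int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  // 16B of accumulator registers per STG shrink to 16/sizeof(acc) output bytes.
  for (int acc_bytes : {4, 2}) {  // int32/fp32 vs. fp16 accumulation
    int bytes_per_stg = div_up(16, acc_bytes);    // 4 or 8
    int cta_n = 64;                               // padded head dim (1 B per element -> 64 B/row)
    int threads_per_row = cta_n / bytes_per_stg;  // 16 or 8

    int dv = 48;                                  // runtime head dim
    int valid_bytes_per_row = dv;                 // 1 byte per int8/fp8 element

    // Example thread: the right-most column of a row.
    int col = threads_per_row - 1;
    bool in_padding = col * bytes_per_stg >= valid_bytes_per_row;

    std::printf("acc=%dB -> stg=%dB threads/row=%d last-thread padded=%d\n",
                acc_bytes, bytes_per_stg, threads_per_row, (int)in_padding);
  }
}
```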
+ is_active_ = ALL_THREADS_ACTIVE || (row < ROWS_PER_LOOP && col_in_bytes_ < VALID_BYTES_PER_ROW); + + // Do not store if the thread is in the padded area + is_active_ = is_active_ && col < valid_bytes_per_row / BYTES_PER_STG; + + // For the scratch space, the pointer has int32 type so it accounts for the *4 factor. + o_scratch_ptr_ += blockIdx.y * STGS_PER_LOOP * ACTIVE_THREADS + tidx; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + if (blockIdx.x == 0) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + dst[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } else if (ALL_THREADS_ACTIVE || is_active_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + fmha::ldg(dst[ii], o_scratch_ptr_ + ii * ACTIVE_THREADS); + } + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +// Iterate over the different STGs. +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + // Break early if we exceed s_i... + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= actual_seqlen_) { + return; + } + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; + // Packs the 32bit/16bit values to 8bit. + // Depending on the type, applies extra scaling with parameter scale_bmm2. + Stg_packed_type dst = Acc_packer::run(this, src[ii]); + float const* row_ptr = reinterpret_cast(&src[ii]); + + // Store the result. + if (is_active_ && (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_))) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, dst); + } + } + } + + // Store data to global memory. + // TODO: 16bit (half) + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + // Do the reduction. + uint4 tmp[STGS_PER_LOOP]; +#if defined(USE_I2F_EMULATION_TRICK) + if (params_enable_i2f_trick_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + float4 const& src_ii = reinterpret_cast(src[ii]); + float4 const& old_ii = reinterpret_cast(old[ii]); + + float x = src_ii.x + old_ii.x; + float y = src_ii.y + old_ii.y; + float z = src_ii.z + old_ii.z; + float w = src_ii.w + old_ii.w; + + tmp[ii].x = reinterpret_cast(x); + tmp[ii].y = reinterpret_cast(y); + tmp[ii].z = reinterpret_cast(z); + tmp[ii].w = reinterpret_cast(w); + } + } else +#endif + { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + int4 const& src_ii = reinterpret_cast(src[ii]); + int4 const& old_ii = reinterpret_cast(old[ii]); + + int32_t x = src_ii.x + old_ii.x; + int32_t y = src_ii.y + old_ii.y; + int32_t z = src_ii.z + old_ii.z; + int32_t w = src_ii.w + old_ii.w; + + tmp[ii].x = reinterpret_cast(x); + tmp[ii].y = reinterpret_cast(y); + tmp[ii].z = reinterpret_cast(z); + tmp[ii].w = reinterpret_cast(w); + } + } + + // The last CTA stores INT8 values to the final location. + if (blockIdx.x == CTAS_PER_HEAD - 1) { + this->store(tmp, mi); + + // Other CTAs store INT32 values to the scratch space. + } else if (ALL_THREADS_ACTIVE || is_active_) { +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + fmha::stg(o_scratch_ptr_ + ii * ACTIVE_THREADS, tmp[ii]); + } + } + } + + // Move the pointer. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The scaling factor to convert to int8. 
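Note that the bmm2 scale is carried around as a `uint32_t` even though it logically holds an FP32 (or packed FP16x2) value; the store paths reinterpret the bits into whatever type they need. A minimal illustration of that round-trip, using `memcpy` as the portable host-side equivalent of the `reinterpret_cast` seen above:

```cpp
// Minimal illustration of passing a float scale through a uint32_t register,
// as done with params_scale_bmm2_ above (host-side sketch).
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float scale = 0.0625f;

  // Pack the float's bits into a uint32_t (what the kernel params carry).
  uint32_t scale_bits;
  std::memcpy(&scale_bits, &scale, sizeof(scale_bits));

  // ... later, reinterpret the bits back into a float before scaling.
  float recovered;
  std::memcpy(&recovered, &scale_bits, sizeof(recovered));

  std::printf("bits=0x%08x  recovered=%f\n", scale_bits, recovered);
}
```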
+ uint32_t const params_scale_bmm2_; + // Do we enable the i2f trick? + bool const params_enable_i2f_trick_; + // The pointer. + char* o_ptr_; + // The pointer to the scratch space to do the reduction (for CTAS_PER_HEAD > 1). + uint4* o_scratch_ptr_; + // The row, col stored by this thread (i.e. the position in that sequence). + int row_, col_in_bytes_; + // The size of the sequence length computed by that CTA. + int actual_seqlen_; + + // Is it an active thread? + int is_active_, is_active_for_last_stg_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Volta_imma_int8_int32_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o + : public Gmem_tile_o_8bit { + // The traits class. + using Traits = fmha::Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_16bit { + // This stores the fp32 accumulators of Ada_qmma_e4m3_fp32_traits as 16bit values to + // the global memory. 
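`Gmem_tile_o_16bit` re-emits the FP32 accumulators as packed 16-bit outputs (uint16/half or bfloat16 via the derived tiles further down). For the bfloat16 flavour the conversion ultimately amounts to "keep the upper 16 bits of the float, rounding to nearest even"; a host-side sketch of that rounding, separate from the packed `float4_to_16bit_x4` device helpers actually used:

```cpp
// Host-side sketch of float -> bfloat16 round-to-nearest-even, the conversion
// that the bf16 output path ultimately performs.  NaN inputs are not handled.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t float_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Add 0x7FFF plus the LSB of the surviving mantissa to round to nearest even.
  uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
  return (uint16_t)((bits + rounding) >> 16);
}

static float bf16_to_float(uint16_t h) {
  uint32_t bits = (uint32_t)h << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  for (float f : {1.0f, 3.1415926f, -0.4062789f, 65504.f}) {
    uint16_t b = float_to_bf16_rne(f);
    std::printf("%12.7f -> 0x%04x -> %12.7f\n", f, (unsigned)b, bf16_to_float(b));
  }
}
```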
+ + static_assert(std::is_same::value); + static_assert(std::is_same::value || + std::is_same::value); + + using Mma_tile = typename Traits::template Mma_tile; + + // The size of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The valid size of a row in bytes. + // Note: cross-attention kernels rely on head dim from runtime instead of from compile-time. + // This approach deviates from self-attention kernels. To explore a unified approach. + enum { VALID_BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = 8 }; + + // The STG packed data type + using Stg_packed_type = typename Uint_from_size_in_bytes::Type; + + // The number of threads to store a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // The number of "rows" stored per STG. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loop for the stores. + enum { LOOPS = fmha::Div_up::VALUE }; + + // DEBUG. + static_assert(ROWS % ROWS_PER_LOOP == 0, ""); + // END OF DEBUG. + + // Make sure the math is correct. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + // The number of STGs needed to store a chunk of the Q matrix in total. + enum { STGS = STGS_PER_LOOP * LOOPS }; + + // Are all threads active? + enum { ALL_THREADS_ACTIVE = ROWS_PER_STG <= ROWS_PER_LOOP }; + + // The number of active threads. + enum { ACTIVE_THREADS_ = Cta_tile::THREADS_PER_CTA * ROWS_PER_LOOP / ROWS_PER_STG }; + + // The number of active threads. + enum { ACTIVE_THREADS = ALL_THREADS_ACTIVE ? Cta_tile::THREADS_PER_CTA : ACTIVE_THREADS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_o_16bit(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(block_info.actual_q_seqlen), + params_scale_bmm2_(params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2) +#ifdef GENERATE_CUBIN + , + params_enable_i2f_trick_(false) +#else + , + params_enable_i2f_trick_(params.enable_i2f_trick) +#endif + , + o_ptr_(reinterpret_cast(params.o_ptr)) +#if USE_DEMO_BERT_PARAMS + , + o_scratch_ptr_(nullptr) { +#else + , + o_scratch_ptr_(reinterpret_cast(params.o_scratch_ptr)) { +#endif + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Is it an active thread for the very last STG? + if (HAS_INCOMPLETE_STG) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + + // Store the row to check against the length before loads. 
+ row_ = cta_row_offset + row; + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_STG; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)row_ * params.o_stride_in_bytes; + // The amount of bytes per row without padding (runtime). + int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT; + // Take the batch/head offset into account. + row_offset += block_info.bidx * valid_bytes_per_row; + // Assemble the final pointer. + o_ptr_ += row_offset + col_in_bytes_; + + // Is it an active thread? + is_active_ = ALL_THREADS_ACTIVE || (row < ROWS_PER_LOOP && col_in_bytes_ < VALID_BYTES_PER_ROW); + + // Do not store if the thread is in the padded area + is_active_ = is_active_ && col < valid_bytes_per_row / BYTES_PER_STG; + + // For the scratch space, the pointer has int32 type so it accounts for the *4 factor. + o_scratch_ptr_ += blockIdx.y * STGS_PER_LOOP * ACTIVE_THREADS + tidx; + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { +// Iterate over the different STGs. +#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + // Break early if we exceed s_i... + int jj = mi * STGS_PER_LOOP + ii; + if (row_ + jj * ROWS_PER_STG >= actual_seqlen_) { + return; + } + using Src_type = typename Traits::Accumulator_type; + // Packs the 32bit/16bit values to 16bit. + // Depending on the type, applies extra scaling with parameter scale_bmm2. + Stg_packed_type dst = Acc_packer::run(this, src[ii]); + float const* row_ptr = reinterpret_cast(&src[ii]); + + // Store the result. + if (is_active_ && (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || is_active_for_last_stg_))) { + fmha::stg(o_ptr_ + jj * ROWS_PER_STG * params_o_stride_in_bytes_, dst); + } + } + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The scaling factor to convert to int8. + uint32_t const params_scale_bmm2_; + // Do we enable the i2f trick? + bool const params_enable_i2f_trick_; + // The pointer. + char* o_ptr_; + // The pointer to the scratch space to do the reduction (for CTAS_PER_HEAD > 1). + uint4* o_scratch_ptr_; + // The row, col stored by this thread (i.e. the position in that sequence). + int row_, col_in_bytes_; + // The size of the sequence length computed by that CTA. + int actual_seqlen_; + + // Is it an active thread? + int is_active_, is_active_for_last_stg_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_uint16 : public Gmem_tile_o_16bit { + using Base = Gmem_tile_o_16bit; + + // Ctor. + template + inline __device__ Gmem_tile_o_uint16(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_bfloat16 + : public Gmem_tile_o_16bit { + using Base = Gmem_tile_o_16bit; + + // Ctor. + template + inline __device__ Gmem_tile_o_bfloat16(Params const& params, Block_info const& block_info, + int tidx, int cta_row_offset = 0, + int cta_col_offset_in_bytes = 0) + : Base(params, block_info, tidx, cta_row_offset, cta_col_offset_in_bytes) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Imma_gmem_tile_o_interleaved { + // The mma tile. 
+ using Mma_tile = typename Traits::template Mma_tile; + + enum { VEC = 32 }; + + enum { NUM_SLICES = Cta_tile::N / VEC }; + + // DEBUG. + static_assert(NUM_SLICES == 1 || NUM_SLICES == 2, ""); + + // END OF DEBUG. + + // The size of each element. + enum { BYTES_PER_ELEMENT = 1 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = VEC * BYTES_PER_ELEMENT }; + + // The size of each STG. + enum { BYTES_PER_STG = 4 }; + + // The number of threads to store a "row" of the matrix. We force it to 8 + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG }; + + // DEBUG. + static_assert(THREADS_PER_ROW == 8 && BYTES_PER_STG == 4, ""); + + // END OF DEBUG. + + // the "logical" number of rows. think of rows per slice + enum { ROWS = Cta_tile::M }; + + // "physical" rows + enum { TOTAL_ROWS = ROWS * NUM_SLICES }; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + enum { ROWS_PER_LOOP_PER_SLICE = Mma_tile::M_PER_MMA_PER_CTA }; + + enum { ROWS_PER_LOOP = Mma_tile::M_PER_MMA_PER_CTA * NUM_SLICES }; + + // DEBUG. + static_assert(ROWS_PER_LOOP == 16 * Cta_tile::WARPS_M * NUM_SLICES, ""); + + // END OF DEBUG. + + // The number of outer loop for the stores. + enum { LOOPS = TOTAL_ROWS / ROWS_PER_LOOP }; + + // Make sure the math is correct. + static_assert(LOOPS == (int)Mma_tile::MMAS_M, ""); + + // The number of "rows" stored per STG -- for it to be the number of rows per MMA instruction. + enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of STGs needed to store a chunk of the Q matrix. + enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; + + enum { STGS_PER_SLICE = STGS_PER_LOOP / NUM_SLICES }; + + // DEBUG. + static_assert((Cta_tile::WARPS_M == 1 && STGS_PER_SLICE == 1) || + (Cta_tile::WARPS_M == 2 && STGS_PER_SLICE == 2), + ""); + + // END OF DEBUG. + + // Ctor. + template + inline __device__ Imma_gmem_tile_o_interleaved(Params const& params, Block_info const& block_info, + int tidx, int cta_row_offset = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(block_info.actual_seqlen - cta_row_offset), + params_scale_bmm2_(params.scale_bmm2), + params_enable_i2f_trick_(params.enable_i2f_trick), + o_ptr_(reinterpret_cast(params.o_ptr)), + total_(params.o_stride_in_bytes) { + int bidh = block_info.bidh; + int sum_s = block_info.sum_s; + + row_ = tidx / THREADS_PER_ROW; + int col = tidx % THREADS_PER_ROW; + + // h is N + // d is H + // want to save as: h x (d/32) x total x 32 (think 3 x h x (d/32) x b x s x 32) + + int block_offset = bidh * NUM_SLICES * total_ + sum_s; // bidh * GROUPS * B * S + b * S + int row_offset = (block_offset + cta_row_offset) * BYTES_PER_ROW; + + o_ptr_ += row_offset + col * BYTES_PER_STG; + } + + // Load data from global memory. + inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) { + int rows_so_far = mi * STGS_PER_LOOP * ROWS_PER_STG; + int rows_so_far_per_slice = rows_so_far / 2; + + // The scale. + float const& scale = reinterpret_cast(params_scale_bmm2_); + +// Iterate over the different STGs. 
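The interleaved tile writes the output in a `(d/32) x total x 32`-style layout, so each STG in the loop that follows is first decomposed into a slice index and a row within that slice before the byte offset is formed. A host-side sketch of that index math, with tile sizes chosen purely for the example:

```cpp
// Host-side sketch of the slice/row decomposition used by the interleaved
// store.  All sizes below are assumptions chosen for illustration.
#include <cstdio>

int main() {
  const int VEC = 32;             // channels per interleaved slice
  const int BYTES_PER_ROW = VEC;  // int8 elements -> 32 bytes per "row"
  const int ROWS_PER_STG = 16;    // rows written by one CTA-wide STG
  const int STGS_PER_LOOP = 2;    // e.g. two slices, one STG each
  const int STGS_PER_SLICE = 1;
  const int total = 4096;         // B * S rows per slice

  const int row = 5;              // this thread's row within the tile
  const int rows_so_far_per_slice = 0;

  for (int ii = 0; ii < STGS_PER_LOOP; ++ii) {
    int slice = ii / STGS_PER_SLICE;  // which 32-channel slice
    int si = ii % STGS_PER_SLICE;     // STG index inside the slice
    int row_in_slice = row + si * ROWS_PER_STG + rows_so_far_per_slice;
    int offset = (slice * total + row_in_slice) * BYTES_PER_ROW;
    std::printf("ii=%d -> slice=%d row_in_slice=%d byte_offset=%d\n",
                ii, slice, row_in_slice, offset);
  }
}
```

This is also why the loop `continue`s rather than returning early when a row falls outside the sequence: later iterations still have to cover the other slice.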
+#pragma unroll + for (int ii = 0; ii < STGS_PER_LOOP; ++ii) { + // if(ii == 1) return; + // decompose the iteration into slice + int slice = ii / STGS_PER_SLICE; + int si = ii % STGS_PER_SLICE; + // dbg 256 + // assert(STGS_PER_SLICE == 1); + // assert(STGS_PER_LOOP == 2); + // assert(slice == ii); + // the number of rows one CTA-wide STG writes + static_assert(ROWS_PER_STG == 16, ""); // only holds for 4 warps/128 threads + int row_in_slice = row_ + si * ROWS_PER_STG + rows_so_far_per_slice; + + // we cannot return early, because the second half of iterates are + // responsible for the bottom slice + if (row_in_slice >= min(actual_seqlen_, ROWS)) { + continue; + } + + int offset = (slice * total_ + row_in_slice) * BYTES_PER_ROW; + + // The accumulators are in int32_t. + int4 const& val = reinterpret_cast(src[ii]); + + // if(threadIdx.x == 96){ + // printf("mi=%d ii=%d S=%d si=%d sofar=%d row=%d as=%d\n", mi, ii, slice, si, + // rows_so_far_per_slice, row_in_slice, actual_seqlen_) ; + // } + + uint32_t dst = quantize(val, scale, params_enable_i2f_trick_); + // Store the result. + fmha::stg(o_ptr_ + offset, dst); + } + } + + // Store data to global memory. + inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], uint4 const (&old)[STGS_PER_LOOP], + int mi) { + static_assert(CTAS_PER_HEAD == 1, "Not implemented"); + } + + // Move the pointer. + inline __device__ void move() { + o_ptr_ += (int64_t)ROWS * BYTES_PER_ROW; + actual_seqlen_ -= ROWS; + } + + // The stride between rows for the QKV matrice. + int64_t const params_o_stride_in_bytes_; + // The scaling factor to convert to int8. + uint32_t const params_scale_bmm2_; + // Do we enable the i2f trick? + bool const params_enable_i2f_trick_; + // The pointer. + char* o_ptr_; + int row_; + int actual_seqlen_; + int total_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace v2 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_ps.h b/csrc/fmha_v2/fmha/gmem_tile_ps.h new file mode 100644 index 0000000000..de150ff293 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_ps.h @@ -0,0 +1,837 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator { + // The fragment. + using Acc = fmha::Fragment_accumulator; + + // Store. 
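+  // Scales the four packed half2 accumulator registers by `scale` and writes them as a
+  // 2x2 pattern of 4-byte stores, step_m / step_n apart.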
+ inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Acc const& acc, + uint32_t scale) { + uint32_t acc_0 = fmha::hmul2(acc.reg(0), scale); + uint32_t acc_1 = fmha::hmul2(acc.reg(1), scale); + uint32_t acc_2 = fmha::hmul2(acc.reg(2), scale); + uint32_t acc_3 = fmha::hmul2(acc.reg(3), scale); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, acc_0); + fmha::stg(ptr + 1 * step_m + 0 * step_n, acc_1); + fmha::stg(ptr + 0 * step_m + 1 * step_n, acc_2); + fmha::stg(ptr + 1 * step_m + 1 * step_n, acc_3); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator { + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { + int32_t tmp_0 = acc.elt(0); + int32_t tmp_1 = acc.elt(1); + int32_t tmp_2 = acc.elt(2); + int32_t tmp_3 = acc.elt(3); + int32_t tmp_4 = acc.elt(4); + int32_t tmp_5 = acc.elt(5); + int32_t tmp_6 = acc.elt(6); + int32_t tmp_7 = acc.elt(7); + +#if defined(USE_I2F_EMULATION_TRICK) + tmp_0 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_1 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_2 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_3 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_4 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_5 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_6 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_7 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); +#endif + + uint32_t acc_0 = reinterpret_cast(tmp_0); + uint32_t acc_1 = reinterpret_cast(tmp_1); + uint32_t acc_2 = reinterpret_cast(tmp_2); + uint32_t acc_3 = reinterpret_cast(tmp_3); + uint32_t acc_4 = reinterpret_cast(tmp_4); + uint32_t acc_5 = reinterpret_cast(tmp_5); + uint32_t acc_6 = reinterpret_cast(tmp_6); + uint32_t acc_7 = reinterpret_cast(tmp_7); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, make_uint2(acc_0, acc_1)); + fmha::stg(ptr + 1 * step_m + 0 * step_n, make_uint2(acc_4, acc_5)); + fmha::stg(ptr + 0 * step_m + 1 * step_n, make_uint2(acc_2, acc_3)); + fmha::stg(ptr + 1 * step_m + 1 * step_n, make_uint2(acc_6, acc_7)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Store_accumulator { + // The instruction traits. + using Traits = Ampere_hmma_fp32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. 
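+  // The scale arrives bit-cast in a uint32_t: reinterpret it as a float, apply it to all
+  // eight FP32 accumulator elements, then write them back as four 8-byte (uint2) stores.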
+ inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { + float const scalef = reinterpret_cast(scale); + + float const tmp_0 = acc.elt(0) * scalef; + float const tmp_1 = acc.elt(1) * scalef; + float const tmp_2 = acc.elt(2) * scalef; + float const tmp_3 = acc.elt(3) * scalef; + float const tmp_4 = acc.elt(4) * scalef; + float const tmp_5 = acc.elt(5) * scalef; + float const tmp_6 = acc.elt(6) * scalef; + float const tmp_7 = acc.elt(7) * scalef; + + uint32_t acc_0 = reinterpret_cast(tmp_0); + uint32_t acc_1 = reinterpret_cast(tmp_1); + uint32_t acc_2 = reinterpret_cast(tmp_2); + uint32_t acc_3 = reinterpret_cast(tmp_3); + uint32_t acc_4 = reinterpret_cast(tmp_4); + uint32_t acc_5 = reinterpret_cast(tmp_5); + uint32_t acc_6 = reinterpret_cast(tmp_6); + uint32_t acc_7 = reinterpret_cast(tmp_7); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, make_uint2(acc_0, acc_1)); + fmha::stg(ptr + 1 * step_m + 0 * step_n, make_uint2(acc_2, acc_3)); + fmha::stg(ptr + 0 * step_m + 1 * step_n, make_uint2(acc_4, acc_5)); + fmha::stg(ptr + 1 * step_m + 1 * step_n, make_uint2(acc_6, acc_7)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Store_accumulator { + // The instruction traits. + using Traits = Ampere_hmma_bf16_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { + float const scalef = reinterpret_cast(scale); + + float const tmp_0 = acc.elt(0) * scalef; + float const tmp_1 = acc.elt(1) * scalef; + float const tmp_2 = acc.elt(2) * scalef; + float const tmp_3 = acc.elt(3) * scalef; + float const tmp_4 = acc.elt(4) * scalef; + float const tmp_5 = acc.elt(5) * scalef; + float const tmp_6 = acc.elt(6) * scalef; + float const tmp_7 = acc.elt(7) * scalef; + + uint32_t acc_0 = reinterpret_cast(tmp_0); + uint32_t acc_1 = reinterpret_cast(tmp_1); + uint32_t acc_2 = reinterpret_cast(tmp_2); + uint32_t acc_3 = reinterpret_cast(tmp_3); + uint32_t acc_4 = reinterpret_cast(tmp_4); + uint32_t acc_5 = reinterpret_cast(tmp_5); + uint32_t acc_6 = reinterpret_cast(tmp_6); + uint32_t acc_7 = reinterpret_cast(tmp_7); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, make_uint2(acc_0, acc_1)); + fmha::stg(ptr + 1 * step_m + 0 * step_n, make_uint2(acc_2, acc_3)); + fmha::stg(ptr + 0 * step_m + 1 * step_n, make_uint2(acc_4, acc_5)); + fmha::stg(ptr + 1 * step_m + 1 * step_n, make_uint2(acc_6, acc_7)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t pack_char2(uint32_t a, uint32_t b) { + uint32_t dst; + asm volatile("prmt.b32 %0, %1, %2, 0x0040;\n" : "=r"(dst) : "r"(a), "r"(b)); + return reinterpret_cast(dst); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator { + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { + // Pack pairs of values. + uint16_t tmp_00 = pack_char2(acc.reg(0), acc.reg(1)); + uint16_t tmp_01 = pack_char2(acc.reg(2), acc.reg(3)); + uint16_t tmp_10 = pack_char2(acc.reg(4), acc.reg(5)); + uint16_t tmp_11 = pack_char2(acc.reg(6), acc.reg(7)); + + // Store to memory. 
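+    // Each packed uint16_t carries two int8 values, so the 2x2 tile is written with four
+    // 2-byte stores.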
+ fmha::stg(ptr + 0 * step_m + 0 * step_n, tmp_00); + fmha::stg(ptr + 1 * step_m + 0 * step_n, tmp_10); + fmha::stg(ptr + 0 * step_m + 1 * step_n, tmp_01); + fmha::stg(ptr + 1 * step_m + 1 * step_n, tmp_11); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Store_accumulator { + // The traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { + int32_t tmp_0 = acc.elt(0); + int32_t tmp_1 = acc.elt(1); + int32_t tmp_2 = acc.elt(2); + int32_t tmp_3 = acc.elt(3); + int32_t tmp_4 = acc.elt(4); + int32_t tmp_5 = acc.elt(5); + int32_t tmp_6 = acc.elt(6); + int32_t tmp_7 = acc.elt(7); + +#if defined(USE_I2F_EMULATION_TRICK) + tmp_0 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_1 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_2 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_3 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_4 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_5 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_6 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); + tmp_7 -= int32_t(FP32_I2F_MAGIC_NUMBER_HEX); +#endif + + uint32_t acc_0 = reinterpret_cast(tmp_0); + uint32_t acc_1 = reinterpret_cast(tmp_1); + uint32_t acc_2 = reinterpret_cast(tmp_2); + uint32_t acc_3 = reinterpret_cast(tmp_3); + uint32_t acc_4 = reinterpret_cast(tmp_4); + uint32_t acc_5 = reinterpret_cast(tmp_5); + uint32_t acc_6 = reinterpret_cast(tmp_6); + uint32_t acc_7 = reinterpret_cast(tmp_7); + + fmha::stg(ptr + 0 * step_m + 0 * step_n, make_uint2(acc_0, acc_1)); + fmha::stg(ptr + 1 * step_m + 0 * step_n, make_uint2(acc_2, acc_3)); + fmha::stg(ptr + 0 * step_m + 1 * step_n, make_uint2(acc_4, acc_5)); + fmha::stg(ptr + 1 * step_m + 1 * step_n, make_uint2(acc_6, acc_7)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Store_accumulator { + // The traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { + // Pack pairs of values. + uint16_t tmp_00 = pack_char2(acc.reg(0), acc.reg(1)); + uint16_t tmp_01 = pack_char2(acc.reg(4), acc.reg(5)); + uint16_t tmp_10 = pack_char2(acc.reg(2), acc.reg(3)); + uint16_t tmp_11 = pack_char2(acc.reg(6), acc.reg(7)); + + // Store to memory. + fmha::stg(ptr + 0 * step_m + 0 * step_n, tmp_00); + fmha::stg(ptr + 1 * step_m + 0 * step_n, tmp_10); + fmha::stg(ptr + 0 * step_m + 1 * step_n, tmp_01); + fmha::stg(ptr + 1 * step_m + 1 * step_n, tmp_11); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_hgmma_fp16_traits, 16> { + // The traits. + using Traits = fmha::Hopper_hgmma_fp16_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 4 / 2 }; + + // The number of accumulator held by each thread, per HGMMA instruction. 
+ enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + uint32_t acc_0 = fmha::hmul2(acc.reg(col_idx * ROWS_PER_THREAD + row_idx), scale); + // float one = 1.f; + // if(col_idx > 2){ + // acc_0 = float2_to_half2(one, one); + // } + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_0); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_qgmma_fp8_fp32_traits, + 32> { + // The traits. + using Traits = fmha::Hopper_qgmma_fp8_fp32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 8 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { + float const scalef = reinterpret_cast(scale); +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + float const acc_0 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 0) * scalef; + float const acc_1 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 1) * scalef; + uint2 acc_; + acc_.x = reinterpret_cast(acc_0); + acc_.y = reinterpret_cast(acc_1); + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_igmma_int8_int32_traits, 32> { + // The traits. + using Traits = fmha::Hopper_igmma_int8_int32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 8 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. 
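+  // The INT32 accumulators are written back unscaled: each pair of elements becomes one
+  // 8-byte (uint2) store at the thread's (row_idx, col_idx) position within the GMMA tile.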
+ inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t scale) { +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + int32_t const acc_0 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 0); + int32_t const acc_1 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 1); + uint2 acc_; + acc_.x = reinterpret_cast(acc_0); + acc_.y = reinterpret_cast(acc_1); + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static __device__ inline uint16_t pack_e4m3x2(float const x, float const y) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + uint16_t storage; + asm volatile("{cvt.rn.satfinite.e4m3x2.f32 %0, %2, %1;}\n" : "=h"(storage) : "f"(x), "f"(y)); + return storage; +#else + assert(false); + return 0; +#endif +} + +static __device__ inline uint16_t pack_e5m2x2(float const x, float const y) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + uint16_t storage; + asm volatile("{cvt.rn.satfinite.e5m2x2.f32 %0, %2, %1;}\n" : "=h"(storage) : "f"(x), "f"(y)); + return storage; +#else + assert(false); + return 0; +#endif +} + +template +__device__ inline uint16_t pack_fp8x2(float const x, float const y); + +template <> +__device__ inline uint16_t pack_fp8x2(float const x, float const y) { + return pack_e4m3x2(x, y); +} + +template <> +__device__ inline uint16_t pack_fp8x2(float const x, float const y) { + return pack_e5m2x2(x, y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_qgmma_fp8_fp32_traits, + 8> { + // The traits. + using Traits = fmha::Hopper_qgmma_fp8_fp32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 8 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + float const acc_0 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 0); + float const acc_1 = acc.elt((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 1); + // uint16_t acc_ = pack_e4m3x2(acc_0, acc_1); + uint16_t acc_ = pack_fp8x2(acc_0, acc_1); + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store_accumulator< + fmha::Hopper_igmma_int8_int32_traits, 8> { + // The traits. + using Traits = fmha::Hopper_igmma_int8_int32_traits; + // The fragment. + using Accumulator = fmha::Fragment_accumulator; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = GMMA_M / 8 / 4 }; + + // The number of columns access by each thread. 
+ // Note there are 2 elements per reg. + enum { COLUMNS_PER_THREAD = GMMA_N / 8 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELEMENT_PER_THREAD = ROWS_PER_THREAD * COLUMNS_PER_THREAD }; + + // Store. + inline __device__ void store(char* ptr, int64_t step_m, int64_t step_n, Accumulator const& acc, + uint32_t) { +#pragma unroll + for (int col_idx = 0; col_idx < COLUMNS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + uint32_t const acc_0 = acc.reg((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 0); + uint32_t const acc_1 = acc.reg((col_idx * ROWS_PER_THREAD + row_idx) * 2 + 1); + uint16_t acc_ = pack_char2(acc_0, acc_1); + int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; + fmha::stg(ptr + offset, acc_); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_ps { + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 2 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // // DEBUG. + // static_assert(BYTES_PER_ROW == 384 || BYTES_PER_ROW == 768 || BYTES_PER_ROW == 1536, ""); + // // END OF DEBUG. + + // Ctor. + inline __device__ Gmem_tile_ps(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx, int cta_row_offset = 0) + : params_stride_in_bytes_(params_stride_in_bytes), + params_scale_(params_scale), + ptr_(reinterpret_cast(ptr)) { + // For storing P and S, we do not take into account variable sequence length. + + // The block index for the batch. + int const bidb = blockIdx.y; + // The block index for the head. + int const bidh = blockIdx.x; + // The block index. + int bidx = bidb * gridDim.x + bidh; + + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Compute the position in the sequence (within the CTA for the moment). + int row = warp % Cta_tile::WARPS_M * Mma_tile::M_PER_MMA + lane / 4 + cta_row_offset; + // Compute the position of the thread in the row. + int col = warp / Cta_tile::WARPS_M * Mma_tile::N_PER_MMA + lane % 4 * ELEMENTS_PER_STG; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + bidx * BYTES_PER_ROW; + // Finalize the pointer. + ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + // A thread holds packet of 2 elements. In 2x2 tile per MMA. + int64_t const step_m = 8 * params_stride_in_bytes_; + int64_t const step_n = 8 * BYTES_PER_ELEMENT; + +// Store the different accumulators. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + int64_t offset = (int64_t)mi * Mma_tile::M_PER_MMA_PER_CTA * params_stride_in_bytes_ + + ni * Mma_tile::N_PER_MMA_PER_CTA * BYTES_PER_ELEMENT; + Store_accumulator delegate; + delegate.store(ptr_ + offset, step_m, step_n, acc[mi][ni], params_scale_); + } + } + } + + // Move to the next location. 
+ inline __device__ void move() { ptr_ += (int64_t)Cta_tile::M * params_stride_in_bytes_; } + + inline __device__ void move_n() { ptr_ += (int64_t)Cta_tile::N * BYTES_PER_ELEMENT; } + + // The stride between rows for the QKV matrice. + int64_t const params_stride_in_bytes_; + // The scale to apply before storing the element. + uint32_t const params_scale_; + // The pointer. + char* ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_ps { + // The traits class. + using Traits = Volta_hmma_fp16_traits; + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 4 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // Ctor. + inline __device__ Gmem_tile_ps(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx, int cta_row_offset = 0) + : params_stride_in_bytes_(params_stride_in_bytes), + params_scale_(params_scale), + ptr_(reinterpret_cast(ptr)) { + // For storing P and S, we do not take into account variable sequence lengths. + + // The block index for the batch. + int const bidb = blockIdx.y; + // The block index for the head. + int const bidh = blockIdx.x; + // The block index. + int bidx = bidb * gridDim.x + bidh; + + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // DEBUG. + static_assert(Mma_tile::M_PER_MMA == 16 && Mma_tile::N_PER_MMA == 16, ""); + // END OF DEBUG. + + // The position of the warp. + int warp_row = warp % Cta_tile::WARPS_M * Mma_tile::M_PER_MMA; + int warp_col = warp / Cta_tile::WARPS_M * Mma_tile::N_PER_MMA; + + // Compute the position of the thread (within the CTA for the moment). + int row = warp_row + (lane & 0x10) / 2 + (lane & 0x07); + int col = warp_col + (lane & 0x08) / 2; + + // // DEBUG. + // printf("tidx=%3d row=%3d col=%3d\n", tidx, row, col); + // // END OF DEBUG. + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = + (int64_t)row * params_stride_in_bytes_ + bidx * BYTES_PER_ROW + cta_row_offset; + + // Finalize the pointer. + ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { +// Store the different accumulators. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + // Scale the accumulators. + uint32_t acc_0 = fmha::hmul2(acc[mi][ni].reg(0), params_scale_); + uint32_t acc_1 = fmha::hmul2(acc[mi][ni].reg(1), params_scale_); + uint32_t acc_2 = fmha::hmul2(acc[mi][ni].reg(2), params_scale_); + uint32_t acc_3 = fmha::hmul2(acc[mi][ni].reg(3), params_scale_); + + // The offsets. + int row = mi * Mma_tile::M_PER_MMA_PER_CTA; + int col = ni * Mma_tile::N_PER_MMA_PER_CTA * BYTES_PER_ELEMENT; + + // The offset in bytes. + int64_t offset = (int64_t)row * params_stride_in_bytes_ + col; + + // In one MMA, 16 FP16s are interleaved between threads i and i+8 in groups of 4. 
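+        // Each uint2 packs 4 FP16 values, so this thread writes two 4-element groups that
+        // sit 8 elements apart; the partner thread fills the remaining groups of the
+        // 16-wide MMA tile.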
+ fmha::stg(&ptr_[offset + 0 * BYTES_PER_ELEMENT], make_uint2(acc_0, acc_1)); + fmha::stg(&ptr_[offset + 8 * BYTES_PER_ELEMENT], make_uint2(acc_2, acc_3)); + } + } + } + + // Move to the next location. + inline __device__ void move() { ptr_ += (int64_t)Cta_tile::M * params_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_stride_in_bytes_; + // The scale to apply before storing the element. + uint32_t const params_scale_; + // The pointer. + char* ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_p : public Gmem_tile_ps { + // The base class. + using Base = Gmem_tile_ps; + + // Ctor. + inline __device__ Gmem_tile_p(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx, int cta_row_offset = 0) + : Base(ptr, params_stride_in_bytes, params_scale, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Not super proud of this. Need to refactor. +template +struct Gmem_tile_ps_hopper { + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 2 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // Ctor. + inline __device__ Gmem_tile_ps_hopper(void* ptr, int64_t const params_stride_in_bytes, + int64_t const bytes_per_row, uint32_t const params_scale, + int tidx) + : params_stride_in_bytes_(params_stride_in_bytes), + params_scale_(params_scale), + ptr_(reinterpret_cast(ptr)) { + // For storing P and S, we do not take into account variable sequence length. + + // The block index for the batch. + int const bidb = blockIdx.y; + // The block index for the head. + int const bidh = blockIdx.x; + // The block index. + int bidx = bidb * gridDim.x + bidh; + + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + int warpgroup_idx = warp / 4; + int warp_idx_within_warpgroup = warp % 4; + + // Compute the position in the sequence (within the CTA for the moment). + int row = warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) + lane / 4; + // Compute the position of the thread in the row. + int col = warpgroup_idx * Mma_tile::N_PER_MMA + lane % 4 * ELEMENTS_PER_STG; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + bidx * bytes_per_row; + // Finalize the pointer. + ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Ctor. + inline __device__ Gmem_tile_ps_hopper(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx) + : Gmem_tile_ps_hopper(ptr, params_stride_in_bytes, BYTES_PER_ROW, params_scale, tidx) {} + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + // A thread holds packet of 2 elements. In 2x2 tile per MMA. + // Need to figure out if we need this for hopper. + int64_t const step_m = 8 * (this->params_stride_in_bytes_); + int64_t const step_n = 8 * BYTES_PER_ELEMENT; + +// Store the different accumulators. 
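+// Each MMA advances M_PER_MMA_PER_CTA rows and N_PER_MMA_PER_CTA columns; the
+// Store_accumulator delegate handles the per-thread layout inside a single MMA.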
+#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + int64_t offset = + (int64_t)mi * Mma_tile::M_PER_MMA_PER_CTA * (this->params_stride_in_bytes_) + + ni * Mma_tile::N_PER_MMA_PER_CTA * BYTES_PER_ELEMENT; + + Store_accumulator delegate; + delegate.store(this->ptr_ + offset, step_m, step_n, acc[mi][ni], this->params_scale_); + } + } + } + + // Move to the next location. + inline __device__ void move() { ptr_ += (int64_t)Cta_tile::M * params_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_stride_in_bytes_; + // The scale to apply before storing the element. + uint32_t const params_scale_; + // The pointer. + char* ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_s : public Gmem_tile_ps { + // The base class. + using Base = Gmem_tile_ps; + + // Ctor. + inline __device__ Gmem_tile_s(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx) + : Base(ptr, params_stride_in_bytes, params_scale, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_s + : public Gmem_tile_ps { + // The base class. + using Base = Gmem_tile_ps; + + // Ctor. + inline __device__ Gmem_tile_s(void* ptr, int64_t const params_stride_in_bytes, + uint32_t const params_scale, int tidx, int cta_row_offset = 0) + : Base(ptr, params_stride_in_bytes, + float_to_half2(reinterpret_cast(params_scale)), tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_qkv.h b/csrc/fmha_v2/fmha/gmem_tile_qkv.h new file mode 100644 index 0000000000..0c0af5c8e4 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_qkv.h @@ -0,0 +1,167 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { +namespace v1 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + // The number of valid columns + int VALID_COLS, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? + bool HEADS_INTERLEAVED, + // Number of matrices + int NUM_MATS = 3, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION = false> +struct Gmem_tile_qkv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. 
+ enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes. + enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The valid number of threads to load a "row" of the matrix. + enum { VALID_THREADS_PER_ROW = VALID_BYTES_PER_ROW / BYTES_PER_LDG }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Make sure we use a single register to store predicates. + static_assert(PRED_REGS == 1, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo, + int tidx, int cta_row_offset = 0) + + // in PACKED_QKV, q_stride = k_stride = v_stride + : params_qkv_stride_in_bytes_(params.q_stride_in_bytes), + qkv_ptr_(reinterpret_cast(params.qkv_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Prepare predicates. + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row + ii * ROWS_PER_LDG < ROWS; + } + + // Pack the predicates. + preds_[0] = fmha::pack_predicates(preds); + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + int64_t row_offset = (int64_t)(row + cta_row_offset) * params_qkv_stride_in_bytes_; + // Add the block index. + int idx; + if (HEADS_INTERLEAVED) { + idx = binfo.bidx * NUM_MATS + qkv_offset; + } else { + idx = (params.b * params.s * NUM_MATS + qkv_offset) * params.h + binfo.bidh; + } + // Assemble the final pointer. + qkv_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col * BYTES_PER_LDG; + + // active threads + is_active_ = col < VALID_THREADS_PER_ROW; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + } + if (USE_LDGSTS) { + smem_tile.store(ptrs, preds_); + } else { + fmha::ldg(fetch_, ptrs, preds_); + } + } + + // Load data from global memory, shared mem is not needed + inline __device__ void load() { + void const* ptrs[LDGS]; + if (is_active_) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + } + fmha::ldg(fetch_, ptrs, preds_); + } + } + + // Move the pointer to the next location. + inline __device__ void move() { qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_; } + + // The stride between rows for the QKV matrice. + int64_t const params_qkv_stride_in_bytes_; + // The pointer. + char const* qkv_ptr_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. 
+ uint4 fetch_[LDGS]; + // The active LDG threads + bool is_active_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace v1 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h new file mode 100644 index 0000000000..00797d0a01 --- /dev/null +++ b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h @@ -0,0 +1,1307 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include + +namespace fmha { +namespace v2 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ldgsts_helper { + template + static inline __device__ void load(This* this_, Smem_tile& smem_tile, void const* (&ptrs)[LDGS], + uint32_t (&preds)[LDGS]) { + fmha::pack_predicates(this_->preds_, preds); + smem_tile.store(ptrs, this_->preds_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Ldgsts_helper<0> { + template + static inline __device__ void load(This* this_, Smem_tile& smem_tile, void const* (&ptrs)[LDGS], + uint32_t (&preds)[LDGS]) { +#if 0 + fmha::pack_predicates(this_->preds_, preds); + fmha::ldg(this_->fetch_, ptrs, this_->preds_); +#else +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + this_->fetch_[ii] = make_uint4(0u, 0u, 0u, 0u); + } + // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) + Ldg_functor fct(this_->fetch_, ptrs); +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + fct.ldgsts(ii, preds[ii]); + } +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT_, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns (padded, e.g 64). + int COLS, + // The actual number of columns (unpadded, e.g 40) + int VALID_COLS_, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? + bool HEADS_INTERLEAVED, + // The number of matrices + int NUM_MATS = 3, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION = false> +struct Gmem_tile_qkv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The number of bits/bytes of element + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT_ / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes (without paddings). + enum { VALID_COLS = VALID_COLS_ }; + + // The amount of bytes that are valid per row. 
+ enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor for bert::Fused_multihead_attention_params_v2 class + template + inline __device__ Gmem_tile_qkv(bert::Fused_multihead_attention_params_v2 const& params, + int qkv_offset, Block_info const& binfo, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h, + qkv_offset, binfo, tidx, params.h_kv, cta_row_offset, + cta_col_offset_in_bytes) {} + + // Ctor for other param classes (such as Qkv_params in train_ops) + template + inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo, + int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h, + qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {} + + // Ctor. + template + inline __device__ Gmem_tile_qkv(void* qkv_ptr, size_t qkv_stride_in_bytes, int d, int dv, + int num_heads, int qkv_offset, Block_info const& binfo, int tidx, + int num_kv_heads = 0, int cta_row_offset = 0, + int cta_col_offset_in_bytes = 0) + : params_qkv_stride_in_bytes_(qkv_stride_in_bytes), + actual_seqlen_(binfo.actual_seqlen), + qkv_ptr_(reinterpret_cast(qkv_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // We must store the value to update the predicates in "load". + row_ = row; + // Do not load/store if the thread is in the padded area + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + int64_t row_offset = (int64_t)(row + cta_row_offset) * params_qkv_stride_in_bytes_; + // Add the byte index. + int64_t idx; + + // Both MQA and GQA will use non HEADS_INTERLEAVED layout + if (num_kv_heads < num_heads) { + int const head_id = binfo.bidh; + int const kv_head_id = binfo.bidh / (num_heads / num_kv_heads); + // QKV layout [b, s, [q_hd, k_h'd, v_h'd]] + idx = binfo.sum_s * params_qkv_stride_in_bytes_; + if (qkv_offset == 0) { // Q tensor + idx += head_id * VALID_BYTES_PER_ROW; + } else if (qkv_offset == 1) { // K tensor + idx += (num_heads + kv_head_id) * VALID_BYTES_PER_ROW; + } else if (qkv_offset == 2) { // V tensor + /* When qkv_offset == 2, this is an instance of Gmem_tile_v defined in Kernel_traits: + using Gmem_tile_v = Gmem_tile_v_; + the 6th template argument is VALID_DV instead of VALID_D. 
+ Thus, here VALID_COLS equals VALID_DV, and + VALID_BYTES_PER_ROW equals VALID_DV * BYTES_PER_ELEMENT, + and `kv_head_id * dv * BYTES_PER_ELEMENT` can be optimized to + `kv_head_id * VALID_BYTES_PER_ROW`. */ + idx += + (num_heads + num_kv_heads) * d * BYTES_PER_ELEMENT + kv_head_id * VALID_BYTES_PER_ROW; + } + } else if (HEADS_INTERLEAVED) { + // [b, s, h, [q_d, k_d, v_d]] aka bsh3d + // bidx = sum_s * params.h + bidh; + idx = (binfo.bidx * (2 * d + dv) + qkv_offset * d) * BYTES_PER_ELEMENT; + } else { + // [b, s, [q_hd, k_hd, v_hd]] aka bs3hd + idx = binfo.sum_s * params_qkv_stride_in_bytes_ + + qkv_offset * num_heads * d * BYTES_PER_ELEMENT + binfo.bidh * VALID_BYTES_PER_ROW; + } + + // Assemble the final pointer. + qkv_ptr_ += row_offset + idx + col_in_bytes_; + + // Take the CTA offset to modify the sequence length. + actual_seqlen_ -= cta_row_offset; + + // Set the initial seq_len and qkv_offset in case of reinterating + actual_seqlen_init_ = actual_seqlen_; + qkv_ptr_init_ = qkv_ptr_; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + preds[ii] &= col_in_bytes_ < VALID_BYTES_PER_ROW; + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. + // The predicates protect against out-of-bound access in rows and cols + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + // Load data from memory. + inline __device__ void load() { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + } + + // Trigger the LDGs. + if (col_in_bytes_ < VALID_BYTES_PER_ROW) { + fmha::pack_predicates(preds_, preds); + fmha::ldg(fetch_, ptrs, preds_); + } else { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + fetch_[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } + } + + // Move the pointer to the next row location. + inline __device__ void move(int const steps = 1) { + qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps; + actual_seqlen_ -= (int)ROWS * steps; + } + + // Move the pointer to the next row location by the offset (not step). 
+ inline __device__ void move_by_offset(int const offset) { + qkv_ptr_ = qkv_ptr_init_ + (int64_t)offset * params_qkv_stride_in_bytes_; + actual_seqlen_ = actual_seqlen_init_ - (int)offset; + } + + // Move the pointer to the next column location + inline __device__ void move_col(int const steps = 1) { + qkv_ptr_ += (int64_t)COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + inline __device__ void reset() { + qkv_ptr_ = qkv_ptr_init_; + actual_seqlen_ = actual_seqlen_init_; + } + + // Rewind the pointer back to previous column location + inline __device__ void rewind_col(int const steps) { + qkv_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + inline __device__ void move_to(int const step) { + qkv_ptr_ = qkv_ptr_init_ + (int64_t)ROWS * params_qkv_stride_in_bytes_ * step; + actual_seqlen_ = actual_seqlen_init_ - (int)ROWS * step; + } + + // Store data to memory. + inline __device__ void store(uint4 const (&data)[LDGS]) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + char* ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + if (((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen_)) && + col_in_bytes_ < VALID_BYTES_PER_ROW /*TODO: double check*/) { + fmha::stg(ptr, data[ii]); + } + } + } + + // The stride between rows for the QKV matrice. + int64_t params_qkv_stride_in_bytes_; + // The pointer. + char* qkv_ptr_; + char* qkv_ptr_init_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + int col_in_bytes_; + // The sequence length. + int actual_seqlen_; + int actual_seqlen_init_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// We expect the Q/K/V layout to be [B, S, H, D] with variable sequence length support. +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT_, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns (padded, e.g 64). + int COLS, + // The actual number of columns (unpadded, e.g 40) + int VALID_COLS_, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? (not used) + bool HEADS_INTERLEAVED = false, + // The number of matrices (not used) + int NUM_MATS = 1, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION = false> +struct Gmem_tile_q_k_v { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The number of bits/bytes of element + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT_ / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes (without paddings). + enum { VALID_COLS = VALID_COLS_ }; + + // The amount of bytes that are valid per row. 
+ enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor + // qkv_offset: 0 for Q, 1 for K, 2 for V + template + inline __device__ Gmem_tile_q_k_v(bert::Fused_multihead_attention_params_v2 const& params, + int qkv_offset, Block_info const& binfo, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) { + int seq_offset = 0; + if (qkv_offset == 0) { + // Q tensor + params_q_k_v_stride_in_bytes_ = params.q_stride_in_bytes; + q_k_v_ptr_ = reinterpret_cast(params.q_ptr); + actual_seqlen_ = binfo.actual_q_seqlen; + seq_offset = binfo.sum_s; + } else if (qkv_offset == 1) { + // K tensor + params_q_k_v_stride_in_bytes_ = params.k_stride_in_bytes; + q_k_v_ptr_ = reinterpret_cast(params.k_ptr); + actual_seqlen_ = binfo.actual_kv_seqlen; + seq_offset = binfo.sum_s_kv; + } else if (qkv_offset == 2) { + // V tensor + params_q_k_v_stride_in_bytes_ = params.v_stride_in_bytes; + q_k_v_ptr_ = reinterpret_cast(params.v_ptr); + actual_seqlen_ = binfo.actual_kv_seqlen; + seq_offset = binfo.sum_s_kv; + } + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // We must store the value to update the predicates in "load". + row_ = row; + // Do not load/store if the thread is in the padded area + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; + + // The row offset in the batched GEMM, including the sequence offset. + int64_t row_offset = + (int64_t)(row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_; + // Add the head index. + int64_t idx = binfo.bidh; + + // Assemble the final pointer. + q_k_v_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_; + + // Take the CTA offset to modify the sequence length. + actual_seqlen_ -= cta_row_offset; + + // Set the initial seq_len and qkv_offset in case of reinterating + actual_seqlen_init_ = actual_seqlen_; + q_k_v_ptr_init_ = q_k_v_ptr_; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + preds[ii] &= col_in_bytes_ < VALID_BYTES_PER_ROW; + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = q_k_v_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_q_k_v_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. 
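+    // With USE_LDGSTS the data is copied straight to shared memory via the asynchronous
+    // LDGSTS path; otherwise it lands in the fetch_ registers and is written to shared
+    // memory later by commit().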
+ // The predicates protect against out-of-bound access in rows and cols + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + // Move the pointer to the next row location. + inline __device__ void move(int const steps = 1) { + q_k_v_ptr_ += (int64_t)ROWS * params_q_k_v_stride_in_bytes_ * steps; + actual_seqlen_ -= (int)ROWS * steps; + } + + // Move the pointer to the next row location by the offset (not step). + inline __device__ void move_by_offset(int const offset) { + q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t)offset * params_q_k_v_stride_in_bytes_; + actual_seqlen_ = actual_seqlen_init_ - (int)offset; + } + + // Move the pointer to the next column location + inline __device__ void move_col() { + q_k_v_ptr_ += (int64_t)COLS * (BITS_PER_ELEMENT / 8); + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG; + } + + // Rewind the pointer back to previous column location + inline __device__ void rewind_col(int const steps) { + q_k_v_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + // Move the pointer to the specified step. + inline __device__ void move_to(int const step) { + q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t)ROWS * params_q_k_v_stride_in_bytes_ * step; + actual_seqlen_ = actual_seqlen_init_ - (int)ROWS * step; + } + + inline __device__ void reset() { + q_k_v_ptr_ = q_k_v_ptr_init_; + actual_seqlen_ = actual_seqlen_init_; + } + + // The stride between rows for the Q/K/V matrice. + int64_t params_q_k_v_stride_in_bytes_; + // The pointer. + char* q_k_v_ptr_; + char* q_k_v_ptr_init_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + int64_t col_in_bytes_; + // The sequence length. + int actual_seqlen_; + int actual_seqlen_init_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Shape [B, S, 2, H, D] where S can be variable sequence length. +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT_, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns (padded, e.g 64). + int COLS, + // The actual number of columns (unpadded, e.g 40) + int VALID_COLS_, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? (Not used) + bool HEADS_INTERLEAVED, + // The number of matrices (Not used) + int NUM_MATS = 2, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION = false> +struct Gmem_tile_contiguous_kv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The number of bits/bytes of element + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT_ / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes (without paddings). + enum { VALID_COLS = VALID_COLS_ }; + + // The amount of bytes that are valid per row. + enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The number of "rows" loaded per LDG. 
+ enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor for bert::Fused_multihead_attention_params_v2 class + template + inline __device__ Gmem_tile_contiguous_kv(bert::Fused_multihead_attention_params_v2 const& params, + int qkv_offset, // q = 0, k = 1, v = 2. + Block_info const& binfo, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : Gmem_tile_contiguous_kv(params.kv_ptr, params.k_stride_in_bytes, params.h_kv, + params.h_q_per_kv, qkv_offset, binfo, tidx, cta_row_offset, + cta_col_offset_in_bytes) {} + + // Ctor. + template + inline __device__ Gmem_tile_contiguous_kv(void* kv_ptr, size_t kv_stride_in_bytes, + int num_kv_heads, int head_group_size, int qkv_offset, + Block_info const& binfo, int tidx, + int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) + : params_kv_stride_in_bytes_(kv_stride_in_bytes), + actual_seqlen_(binfo.actual_kv_seqlen), + kv_ptr_(reinterpret_cast(kv_ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // We must store the value to update the predicates in "load". + row_ = row; + // Do not load/store if the thread is in the padded area + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; + + // The row offset in the batched GEMM. + int64_t row_offset = (int64_t)(row + cta_row_offset) * params_kv_stride_in_bytes_; + // [b, s, 2, h_kv, d]. + int64_t idx = + (binfo.sum_s_kv * 2 + qkv_offset - 1) * num_kv_heads + (binfo.bidh / head_group_size); + + // Assemble the final pointer. + kv_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_; + + // Take the CTA offset to modify the sequence length. + actual_seqlen_ -= cta_row_offset; + + // Set the initial seq_len and qkv_offset in case of reinterating + actual_seqlen_init_ = actual_seqlen_; + kv_ptr_init_ = kv_ptr_; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + preds[ii] &= col_in_bytes_ < VALID_BYTES_PER_ROW; + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = kv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_kv_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. + // The predicates protect against out-of-bound access in rows and cols + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + // Load data from memory. 
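+  // This overload loads directly into the fetch_ registers, without staging the data
+  // through shared memory.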
+ inline __device__ void load() { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_); + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = kv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_kv_stride_in_bytes_; + } + + // Trigger the LDGs. + if (col_in_bytes_ < VALID_BYTES_PER_ROW) { + fmha::pack_predicates(preds_, preds); + fmha::ldg(fetch_, ptrs, preds_); + } else { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + fetch_[ii] = make_uint4(0u, 0u, 0u, 0u); + } + } + } + + // Move the pointer to the next row location. + inline __device__ void move(int const steps = 1) { + kv_ptr_ += (int64_t)ROWS * params_kv_stride_in_bytes_ * steps; + actual_seqlen_ -= (int)ROWS * steps; + } + + // Move the pointer to the next row location by the offset (not step). + inline __device__ void move_by_offset(int const offset) { + kv_ptr_ = kv_ptr_init_ + (int64_t)offset * params_kv_stride_in_bytes_; + actual_seqlen_ = actual_seqlen_init_ - (int)offset; + } + + // Move the pointer to the next column location + inline __device__ void move_col(int const steps = 1) { + kv_ptr_ += (int64_t)COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + inline __device__ void reset() { + kv_ptr_ = kv_ptr_init_; + actual_seqlen_ = actual_seqlen_init_; + } + + // Rewind the pointer back to previous column location + inline __device__ void rewind_col(int const steps) { + kv_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + inline __device__ void move_to(int const step) { + kv_ptr_ = kv_ptr_init_ + (int64_t)ROWS * params_kv_stride_in_bytes_ * step; + actual_seqlen_ = actual_seqlen_init_ - (int)ROWS * step; + } + + // Store data to memory. + inline __device__ void store(uint4 const (&data)[LDGS]) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + char* ptr = kv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_kv_stride_in_bytes_; + if (((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen_)) && + col_in_bytes_ < VALID_BYTES_PER_ROW /*TODO: double check*/) { + fmha::stg(ptr, data[ii]); + } + } + } + + // The stride between rows for the QKV matrice. + int64_t params_kv_stride_in_bytes_; + // The pointer. + char* kv_ptr_; + char* kv_ptr_init_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + int col_in_bytes_; + // The sequence length. + int actual_seqlen_; + int actual_seqlen_init_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// We expect the paged KV layout to be blocks of indices with shape of [B, 2, Blocks_per_Seq], +// and the indice tells the memory distance to the pool ptr in global memory. + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT_, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns (padded, e.g 64). 
+ int COLS, + // The actual number of columns (unpadded, e.g 40) + int VALID_COLS_, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? (not used) + bool HEADS_INTERLEAVED = false, + // The number of matrices (not used) + int NUM_MATS = 2, + // Is sliding window attention used ? + bool SLIDING_WINDOW_ATTENTION_ = false> +struct Gmem_tile_paged_kv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The number of bits/bytes of element + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT_ / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The valid size of a row in bytes (without paddings). + enum { VALID_COLS = VALID_COLS_ }; + + // The amount of bytes that are valid per row. + enum { VALID_BYTES_PER_ROW = VALID_COLS * BITS_PER_ELEMENT / 8 }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is sliding window attention used ? + enum { SLIDING_WINDOW_ATTENTION = SLIDING_WINDOW_ATTENTION_ }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_paged_kv(bert::Fused_multihead_attention_params_v2 const& params, + int qkv_offset, // q = 0, k = 1, v = 2. + Block_info const& binfo, int tidx, int cta_row_offset = 0, + int cta_col_offset_in_bytes = 0) + : actual_seqlen_(binfo.actual_seqlen), + past_seqlen_(binfo.actual_seqlen - binfo.actual_q_seqlen), + sliding_window_size_(params.sliding_window_size), + paged_kv_log2_block_size_(params.paged_kv_cache.mTokensPerBlockLog2), + paged_kv_block_pool_ptr_(reinterpret_cast(params.paged_kv_cache.mPoolPtr)), + paged_kv_global_block_offsets_(params.paged_kv_cache.mBlockOffsets), + params_kv_block_size_in_bytes_(params.paged_kv_cache.mBytesPerBlock) { + // Handle Paged KV with shape [S, Dh], by offsetting it to the target batch. + int32_t const paged_kv_block_offset = + (binfo.bidb * 2 + qkv_offset - 1) * params.paged_kv_cache.mMaxBlocksPerSeq; + paged_kv_global_block_offsets_ += paged_kv_block_offset; + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // We must store the value to update the predicates in "load". + row_ = row; + // Do not load/store if the thread is in the padded area + col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; + + int64_t kv_stride_in_bytes = + qkv_offset == 1 ? params.k_stride_in_bytes : params.v_stride_in_bytes; + // The head offset. 
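+    // Several Q heads may share one KV head (e.g. grouped-query attention),
+    // so the KV head index is bidh / h_q_per_kv (integer division).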
+ head_stride_in_bytes_ = (int64_t)(binfo.bidh / params.h_q_per_kv) * kv_stride_in_bytes; + // When V is padded (like MLA), we cannot use VALID_BYTES_PER_ROW + token_stride_in_bytes_ = kv_stride_in_bytes >> paged_kv_log2_block_size_; + + // Take the CTA offset to modify the sequence length. + // Actually we don't need that for flash attention. + actual_seqlen_ -= cta_row_offset; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + // Prepare the predicates. + uint32_t preds[LDGS]; + // Prepare the load pointers. + void const* ptrs[LDGS]; + + // Offset for the new paged kv pointer. + uint64_t const head_col_in_bytes = head_stride_in_bytes_ + col_in_bytes_; + +// Update paged_kv ptr for each LDG (reuse is possible). +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + int row_idx = row_ + ii * (int)ROWS_PER_LDG; + int paged_kv_block_idx = (row_idx >> paged_kv_log2_block_size_); + char const* local_kv_ptr = reinterpret_cast( + paged_kv_block_pool_ptr_ + + params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx]); + + // Predicates. + // TODO: do we need to make sure row_idx < ROWS ? + preds[ii] = row_idx < actual_seqlen_; + preds[ii] &= col_in_bytes_ < VALID_BYTES_PER_ROW; + + // Pointers. + int row_idx_in_block = row_idx & ((1 << paged_kv_log2_block_size_) - 1); + ptrs[ii] = + local_kv_ptr + head_col_in_bytes + (int64_t)row_idx_in_block * token_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. + // The predicates protect against out-of-bound access in rows and cols + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + // Move the pointer to the next row location. + inline __device__ void move() { row_ += ROWS; } + + // Move the pointer to the next row location by the offset (not step). + inline __device__ void move_by_offset(int const offset) { row_ += offset; } + + // Move the pointer to the next column location + inline __device__ void move_col() { col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG; } + + // Rewind the pointer back to previous column location + inline __device__ void rewind_col(int const steps) { + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps; + } + + // The stride between rows for the KV matrice. + int64_t params_kv_block_size_in_bytes_; + // The paged cache pool pointer. + char* paged_kv_block_pool_ptr_; + // The paged block offsets. + int32_t* paged_kv_global_block_offsets_; + // The paged block size. + int paged_kv_log2_block_size_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + int64_t col_in_bytes_; + // Keep track of the head offset. + int64_t head_stride_in_bytes_; + // // for DeepSeek MLA, the stride of V tokens != VALID_BYTES_PER_ROW + int32_t token_stride_in_bytes_; + // The sequence length. + int actual_seqlen_; + // The past sequence length (kv_seqlen - q_seqlen) considering chunked context. + int past_seqlen_; + // The sliding attention window size. + int sliding_window_size_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. 
+ typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? + bool HEADS_INTERLEAVED, + // The number of matrices + int NUM_MATS = 1> +struct Gmem_tile_q_kv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The padded to the next power of 2 number of columns + enum { COLS_PADDED = Next_power_of_two::VALUE }; + + // The padded size of a row in bytes. + enum { BYTES_PER_ROW_PADDED = COLS_PADDED * BITS_PER_ELEMENT / 8 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a padded "row" of the matrix. + enum { THREADS_PER_ROW_PADDED = BYTES_PER_ROW_PADDED / BYTES_PER_LDG }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW_PADDED }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor. + template + inline __device__ Gmem_tile_q_kv(Params const& params, int offset, Block_info const& binfo, + int tidx, int cta_row_offset = 0) + : params_stride_in_bytes_(params.stride_in_bytes), + actual_seqlen_(binfo.actual_seqlen), + ptr_(reinterpret_cast(params.ptr)) { + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW_PADDED; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW_PADDED; + + // We must store the value to update the predicates in "load". + row_ = row; + // Mask for predicate if the channels are in the padded area + int const bytes_per_row_non_padded = params.d * BITS_PER_ELEMENT / 8; + mask_ = col < bytes_per_row_non_padded / BYTES_PER_LDG; + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + int64_t row_offset = (int64_t)(row + cta_row_offset) * params.stride_in_bytes; + // Add the block index. + int64_t idx; + if (HEADS_INTERLEAVED) { + idx = binfo.bidx * NUM_MATS + offset; + } else { + idx = (binfo.sum_s * NUM_MATS + offset) * params.h + binfo.bidh; + } + // Assemble the final pointer. + ptr_ += row_offset + idx * bytes_per_row_non_padded + col * BYTES_PER_LDG; + + // Take the CTA offset to modify the sequence length. + actual_seqlen_ -= cta_row_offset; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. 
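+  // (Rows are guarded by the sequence-length predicate and columns by mask_,
+  //  which the constructor derived from the unpadded head size params.d.)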
+ template + inline __device__ void load(Smem_tile& smem_tile) { + uint32_t preds[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + preds[ii] = (row_ + ii * (int)ROWS_PER_LDG < min((int)ROWS, actual_seqlen_)) && mask_; + } + + // Prepare the load pointers. + void const* ptrs[LDGS]; +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + ptrs[ii] = ptr_ + (int64_t)ii * ROWS_PER_LDG * params_stride_in_bytes_; + } + + // Trigger LDGSTS or the LDGs. + Ldgsts_helper::load(this, smem_tile, ptrs, preds); + } + + inline __device__ void move(int const steps = 1) { + ptr_ += (int64_t)ROWS * params_stride_in_bytes_ * steps; + actual_seqlen_ -= (int)ROWS * steps; + } + + // Store data to memory. + inline __device__ void store(uint4 const (&data)[LDGS]) { +#pragma unroll + for (int ii = 0; ii < LDGS; ++ii) { + char* ptr = ptr_ + (int64_t)ii * ROWS_PER_LDG * params_stride_in_bytes_; + if ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen_)) { + fmha::stg(ptr, data[ii]); + } + } + } + + // The stride between rows for the matrix. + int64_t params_stride_in_bytes_; + // The pointer. + char* ptr_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row and col the thread is processing as we move the tile. + int row_; + // Keep track of predicate state that depends only on the initialization state. + int mask_; + // The sequence length. + int actual_seqlen_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + // Do we use LDGSTS? + bool USE_LDGSTS_> +struct Gmem_tile_qkv_interleaved { + // The vectorization width for NC/32HW32. + enum { VEC = 32 }; + + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = VEC * BITS_PER_ELEMENT / 8 }; + + // DEBUG. + static_assert(BYTES_PER_ROW == 32, ""); + + // END OF DEBUG. + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // DEBUG. + static_assert(THREADS_PER_ROW == 2, ""); + + // END OF DEBUG. + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of slices. It is either 1 for DIM_PER_HEAD == 32 and 2 for DIM_PER_HEAD == 64. + enum { NUM_SLICES = COLS / VEC }; + + // DEBUG. + static_assert(NUM_SLICES == 1 || NUM_SLICES == 2, ""); + + // END OF DEBUG. + + // The number of rows in a slice. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Make sure we use a single register to store predicates. + static_assert(PRED_REGS == 1, ""); + + // Do we use LDGSTS on Ampere? + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // Ctor. 
+ template + inline __device__ Gmem_tile_qkv_interleaved(Params const& params, int qkv_select, + Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : actual_seqlen_(block_info.actual_seqlen - cta_row_offset), + total_(params.q_stride_in_bytes), + kv_ptr_(reinterpret_cast(params.qkv_ptr)) { + int bidh = block_info.bidh; + int sum_s = block_info.sum_s; + + // We must keep track of the row to repack predicates in load. + row_ = tidx / THREADS_PER_ROW; + // The column. + int col = tidx % THREADS_PER_ROW; + + // h is N + // d is H + // we get the data in as: 3 x h x (d/32) x total x 32 (think 3 x h x (d/32) + // x b x s x 32) + + // Loading qkv: ignore slice for now. + int qkv_offset = qkv_select * params.h * NUM_SLICES * total_; + // bidh * GROUPS * B * S + b * S. + int block_offset = bidh * NUM_SLICES * total_ + sum_s; + // The row offset. + int row_offset = (qkv_offset + block_offset + cta_row_offset) * BYTES_PER_ROW; + + // That's the pointer to load from (see "load"). + kv_ptr_ += row_offset + col * BYTES_PER_LDG; + + init_actual_seqlen_ = actual_seqlen_; + init_kv_ptr_ = kv_ptr_; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) { + if (!USE_LDGSTS) { + smem_tile.store(fetch_); + } + } + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + void const* ptrs[LDGS]; + uint32_t preds[LDGS]; + +// We precompute slice offsets and predicates +#pragma unroll + for (int ii = 0; ii < LDGS; ii++) { + // the next row + int row_i = row_ + ii * ROWS_PER_LDG; + + // Decompose the current row in slice and original row + int slice = row_i / ROWS; + // The position in the slice. + int row_in_slice = row_i % ROWS; + + // Update the predicate. + preds[ii] = row_in_slice < min(actual_seqlen_, ROWS); + // Compute the pointer. + ptrs[ii] = &kv_ptr_[(slice * total_ + row_in_slice) * BYTES_PER_ROW]; + } + + // Update the predicate register. + fmha::pack_predicates(preds_, preds); + + // Trigger the loads. + if (USE_LDGSTS) { + smem_tile.store(ptrs, preds_); + } else { + fmha::ldg(fetch_, ptrs, preds_); + } + } + + // Move the pointer to the next location. + inline __device__ void move(int const steps = 1) { + kv_ptr_ += (int64_t)ROWS * BYTES_PER_ROW * steps; + actual_seqlen_ -= ROWS * steps; + } + + // Reset to the initial location. + inline __device__ void reset() { + kv_ptr_ = init_kv_ptr_; + actual_seqlen_ = init_actual_seqlen_; + } + + // The pointer. + char const* kv_ptr_; + char const* init_kv_ptr_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // keep track of the row the thread is processing as we move the tile + int row_; + // The sequence length. + int actual_seqlen_; + int init_actual_seqlen_; + // The number of rows per slice?? + int total_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace v2 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/arrive_wait.h b/csrc/fmha_v2/fmha/hopper/arrive_wait.h new file mode 100644 index 0000000000..6448d82607 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/arrive_wait.h @@ -0,0 +1,396 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +// CP ASYNC FEATURES /////////////////////////////////////////////////////////////////////////////// +#if !defined(CUDA_CP_ASYNC_SUPPORTED) && \ + ((__CUDACC_VER_MAJOR__ >= 11) || \ + ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2))) +#define CUDA_CP_ASYNC_SUPPORTED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_ENABLED) && (CUDA_CP_ASYNC_SUPPORTED) +#define CUDA_CP_ASYNC_ENABLED 1 +#endif + +#if CUDA_CP_ASYNC_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +#define CUDA_CP_ASYNC_ACTIVATED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_GROUP_POLICY_SUPPORTED) && (CUDA_CP_ASYNC_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ >= 11) +#define CUDA_CP_ASYNC_GROUP_POLICY_SUPPORTED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_GROUP_POLICY_ENABLED) && (CUDA_CP_ASYNC_GROUP_POLICY_SUPPORTED) +#define CUDA_CP_ASYNC_GROUP_POLICY_ENABLED 1 +#endif + +#if CUDA_CP_ASYNC_GROUP_POLICY_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +#define CUDA_CP_ASYNC_GROUP_POLICY_ACTIVATED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_MBARRIER_ARRIVE_SUPPORTED) && (CUDA_CP_ASYNC_SUPPORTED) && \ + (__CUDACC_VER_MAJOR__ >= 11) +#define CUDA_CP_ASYNC_MBARRIER_ARRIVE_SUPPORTED 1 +#endif + +#if !defined(CUDA_CP_ASYNC_MBARRIER_ARRIVE_ENABLED) && (CUDA_CP_ASYNC_MBARRIER_ARRIVE_SUPPORTED) +#define CUDA_CP_ASYNC_MBARRIER_ARRIVE_ENABLED 1 +#endif + +#if (CUDA_CP_ASYNC_MBARRIER_ARRIVE_ENABLED) && (__CUDA_ARCH__ >= 800) +#define CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED 1 +#endif + +#if (CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED) && (CUDACC_VERSION >= 111) +#define CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED 1 +#endif + +#if !defined(FMHA_PTX_MBARRIER_TRYWAIT_NOSLEEP_INTERNAL_SUPPORT_ENABLED) +#define FMHA_PTX_MBARRIER_TRYWAIT_NOSLEEP_INTERNAL_SUPPORT_ENABLED 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +inline __device__ void named_barrier_arrive(uint32_t BARRIER_ID, uint32_t NUM_THREADS) { + if (NUM_THREADS > 1) { + asm volatile("bar.arrive %0, %1;" : : "r"(BARRIER_ID), "r"(NUM_THREADS)); + } +} + +inline __device__ void named_barrier_wait(uint32_t BARRIER_ID, uint32_t NUM_THREADS) { + if (NUM_THREADS > 1) { + asm volatile("bar.sync %0, %1;" ::"r"(BARRIER_ID), "r"(NUM_THREADS)); + } +} + +// it is executed per thread, i.e., each thread can call and init a barrier. +// need a bar.sync after using it. 
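+// Typical flow (illustrative only): one thread calls bar_create(), the CTA
+// synchronizes, producers then signal through one of the arrive helpers below,
+// and consumers poll bar_peek()/bar_wait() with a phase bit that flips each
+// time the barrier completes a generation.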
+inline __device__ void bar_create(void* bar_ptr, int init_count) { + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + + asm volatile( + "{\n\t" +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + "mbarrier.init.shared.b64 [%1], %0; \n\t" +#else + ".reg .s32 negCnt, count, expectedCount;\n\t" + ".reg .s64 comboCnt; \n\t" + "neg.s32 negCnt, %0;\n\t " + "and.b32 count, negCnt, 0x7fffffff; \n\t" + "and.b32 expectedCount, negCnt, 0x3fffffff; \n\t" + "mov.b64 comboCnt, {expectedCount, count}; \n\t" + "st.shared.s64 [%1], comboCnt; \n\t" +#endif + "}" + : + : "r"(init_count), "r"(smem_ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Arrive_wait { + public: + inline __device__ Arrive_wait() { bar_base_ = NULL; } + + inline __device__ Arrive_wait(uint64_t* bar_base, int id = 0) { + bar_base_ = bar_base; + id_ = id; + } + + inline __device__ uint64_t* get_bar_addr(int32_t id) { + return reinterpret_cast(bar_base_ + id); + } + + inline __device__ int bar_peek(int id, unsigned int bar_phase) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + uint32_t result32; +#if FMHA_PTX_MBARRIER_TRYWAIT_NOSLEEP_INTERNAL_SUPPORT_ENABLED + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + "mbarrier.try_wait.parity.nosleep.shared.b64 P3, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P3; \n\t" + "}" + : "=r"(result32) + : "r"(smem_ptr), "r"(bar_phase)); +#else + // public ptx default heruistic generate SASS equal to with .nosleep in internal ptx + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + "mbarrier.try_wait.parity.shared.b64 P3, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P3; \n\t" + "}" + : "=r"(result32) + : "r"(smem_ptr), "r"(bar_phase)); +#endif + return result32; +#else + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned int output_phase = (bar_ptr[0] >> 63) & 1; + + return output_phase != bar_phase; +#endif + } + + inline __device__ int bar_peek(int id, unsigned int bar_phase, int pred) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + uint32_t result32; +#if FMHA_PTX_MBARRIER_TRYWAIT_NOSLEEP_INTERNAL_SUPPORT_ENABLED + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + ".reg .pred P2;\n\t" + "setp.eq.u32 P2, %3, 1;\n\t" + "@P2 mbarrier.try_wait.parity.nosleep.shared.b64 P3, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P3; \n\t" + "}" + : "=r"(result32) + : "r"(smem_ptr), "r"(bar_phase), "r"(pred)); +#else + // public ptx default heruistic generate SASS equal to with .nosleep in internal ptx + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + ".reg .pred P2;\n\t" + "setp.eq.u32 P2, %3, 1;\n\t" + "@P2 mbarrier.try_wait.parity.shared.b64 P3, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P3; \n\t" + "}" + : "=r"(result32) + : "r"(smem_ptr), "r"(bar_phase), "r"(pred)); +#endif + return result32; +#else + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned int output_phase = (bar_ptr[0] >> 63) & 1; + + return output_phase != bar_phase; +#endif + } + + inline __device__ void bar_wait(int id, unsigned int bar_phase) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + uint32_t large_val = 0x989680; + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" + "LAB_WAIT: \n\t" + 
//"mbarrier.try_wait.parity.b64 P3, [%0], %1; \n\t" + "mbarrier.try_wait.parity.shared.b64 P3, [%0], %1, %2; \n\t" + "@P3 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_ptr), "r"(bar_phase), "r"(large_val)); +#else + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + + asm volatile( + "{\n\t" + ".reg .pred P3; \n\t" +#ifdef CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED + "mbarrier.test_wait.parity.shared.b64 P3, [%0], %1;\n\t" +#else + ".reg .s32 high, low; \n\t" + ".reg .u32 currentPhase; \n\t" + "ld.volatile.shared.v2.s32 { low, high }, [%0]; \n\t" + "shr.u32 currentPhase, high, 31; \n\t" + "setp.ne.u32 P3, currentPhase, %1; \n\t" +#endif + "@P3 bra.uni DONE; \n\t" + "LAB_WAIT: \n\t" +#ifdef CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED + "mbarrier.test_wait.parity.shared.b64 P3, [%0], %1;\n\t" +#else + "ld.volatile.shared.v2.s32 { low, high }, [%0]; \n\t" + "shr.u32 currentPhase, high, 31; \n\t" + "setp.ne.u32 P3, currentPhase, %1; \n\t" +#endif + "@P3 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_ptr), "r"(bar_phase)); +#endif + } + + // Set the expected_transaction_count and add 1 arrive count (1 transaction = 1 Byte) + // This PTX maps to SYNCS.ARRIVES.TRANS64.A1TR. + inline __device__ void bar_arrive_set_transactioncnt(int id, int expected_copy_bytes) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1; \n\t" + "}" + : + : "r"(smem_ptr), "r"(expected_copy_bytes)); +#endif + } + + // Set the expected_transaction_count and add 1 arrive count (1 transaction = 1 Byte) + // This PTX maps to SYNCS.ARRIVES.TRANS64.A1TR. 
+ inline __device__ void bar_arrive_set_transactioncnt(int id, int expected_copy_bytes, + uint32_t pred) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1; \n\t" + "}" + : + : "r"(smem_ptr), "r"(expected_copy_bytes), "r"(pred)); +#endif + } + + // Sends barrier arrive notification to DSMEM + // Note this uses a slightly different syntax compared to normal arrive + // NOTE : Caller has to ensure that set_bar_base_dsmem has been called prior to using this + // This is done as a compiler optimizations (since set barrier base is independent) + inline __device__ void bar_arrive_dsmem(int const& id) { +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + // TODO : check with PTX team on setctarank (currently emitting errors) + // asm volatile("{\n\t" + //"setctarank.shared.u32 %0, %1, %2;\n\t" + //"}" + // : "=r"(dst_ptr) : "r"(smem_ptr), "r"(cta_id)); + + asm volatile( + "{\n\t" + "mbarrier.arrive.b64 _, [%0];\n\t" + "}" + : + : "l"(bar_ptr)); +#endif + } + + // Just a predicated version of the above function + // Manually inlining it - since the compiler generates BRA instructions at the moment + // NOTE : Caller has to ensure that set_bar_base_dsmem has been called prior to using this + // This is done as a compiler optimizations (since set barrier base is independent) + inline __device__ void bar_arrive_dsmem(int const& id, uint32_t const& pred) { +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + asm volatile( + "{\n\t" + " .reg .pred p;\n\t" + " .reg .s64 addr;\n\t" + " .reg .b64 tmp;\n\t" + " setp.eq.u32 p, %2, 1;\n\t" + " mul.wide.s32 tmp, %0, 8;\n\t" + " add.s64 addr, tmp, %1;\n\t" + "@p mbarrier.arrive.b64 _, [addr];\n\t" + "}" + : + : "r"(id), "l"(bar_base_), "r"(pred)); +#endif + } + + // Sets up the base address for arrival with the correct ctaid in cga + inline __device__ void set_bar_base_dsmem(uint32_t const& cta_id) { + bar_base_ = reinterpret_cast( + ((unsigned long long int)bar_base_ & 0xFFFFFFFFF0FFFFFFULL) + (cta_id << 24)); + } + + inline __device__ void bar_arrive_normal(int id, bool flag = true) { +#if CUDA_CP_ASYNC_ACTIVATED && !(CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED) + asm("membar.cta;"); +#endif + + // to make distance for the dependence between atoms.arrive and shfl + if (flag == true) { + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + + asm volatile( + "{\n\t" + ".reg .b64 state; \n\t" + "mbarrier.arrive.shared.b64 state, [%0];\n\t" + "}" + : + : "r"(smem_ptr)); + +#elif CUDA_CP_ASYNC_ACTIVATED + + asm volatile( + "{\n\t" + ".reg .b64 state; \n\t" + "atom.shared.arrive.b64 state, [%0];" + "}" + : + : "r"(smem_ptr)); +#endif + } + } + + inline __device__ void bar_arrive_ldgsts(int id) { + uint64_t* bar_ptr = reinterpret_cast(bar_base_ + id); + unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr); + +#if CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED + asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0];" : : "r"(smem_ptr)); +#elif CUDA_CP_ASYNC_ACTIVATED + asm volatile("cp.async.arrive.shared.b64 [%0];" : : "r"(smem_ptr)); +#endif + } + + inline __device__ uint64_t* bar_base() { return bar_base_; } + + private: + // smem barrier base pointer + uint64_t* 
bar_base_; + // barrier id + int id_; +}; + +// Set the expected_transaction_count and add 1 arrive count (1 transaction = 1 Byte) +// This PTX maps to SYNCS.ARRIVES.TRANS64.A1TR. +inline __device__ void bar_arrive_set_transactioncnt(unsigned smem_ptr, + unsigned expected_copy_bytes) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "{\n\t" + "mbarrier.arrive.expect_copy.shared.b64 _, [%0], %1; \n\t" + "}" + : + : "r"(smem_ptr), "r"(expected_copy_bytes)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/compute_tile.h b/csrc/fmha_v2/fmha/hopper/compute_tile.h new file mode 100644 index 0000000000..e08c36fc7f --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/compute_tile.h @@ -0,0 +1,503 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include + +namespace fmha { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_tile_with_gmma {}; + +/* +compute tile used when both operands are coming from SMEM +*/ +template +struct Compute_tile_with_gmma { + static constexpr int NUM_KBLOCKS = Smem_tile_b::BUFFERS_PER_TILE / Cta_tile::WARPS_K; + static_assert(NUM_KBLOCKS * Cta_tile::WARPS_K == Smem_tile_b::BUFFERS_PER_TILE); + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // desc for A and B should have the same strategy + static_assert(Smem_tile_a::Gmma_descriptor::GMMA_DESC_SIZE_PER_GROUP == + Smem_tile_b::Gmma_descriptor::GMMA_DESC_SIZE_PER_GROUP, + "GMMA desc for A and B should have the same strategy."); + + // The number of MMAs. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + enum { MMAS_K = Mma_tile::MMAS_K }; + + // Ctor. + inline __device__ Compute_tile_with_gmma() {} + + // Ctor, that helps set the gmma descs to support different buffer index as the start address. + inline __device__ Compute_tile_with_gmma(void* a_smem_, void* b_smem_) + : Compute_tile_with_gmma(__nvvm_get_smem_pointer(a_smem_), __nvvm_get_smem_pointer(b_smem_)) { + } + + inline __device__ Compute_tile_with_gmma(uint32_t a_smem_base, uint32_t b_smem_base) + : a_smem_base_(a_smem_base), b_smem_base_(b_smem_base) { + // We always start at buffer 0. + uint32_t a_smem = a_smem_base_; + uint32_t b_smem = b_smem_base_; + +#pragma unroll + for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) { + gmma_desc_a_[mma_m_idx].set_smem_pointer(a_smem + + mma_m_idx * Smem_tile_a::GMMA_GROUP_SMEM_DISTANCE); + // We take the number of buffers directly from the Smem_tile. If we have only one buffer, the + // return offset is 0. 
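+      // The stored value is the offset of the last buffer (in the descriptor's
+      // 16-byte units); the increment helpers compare the live descriptor
+      // against it to decide when to wrap back to the first buffer.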
+ gmma_desc_a_[mma_m_idx].set_max_descriptor_0(Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB * + (Smem_tile_a::BUFFERS_PER_TILE - 1)); + } + +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + gmma_desc_b_[mma_n_idx].set_smem_pointer(b_smem + + mma_n_idx * Smem_tile_b::GMMA_GROUP_SMEM_DISTANCE); + gmma_desc_b_[mma_n_idx].set_max_descriptor_0(Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB * + (Smem_tile_b::BUFFERS_PER_TILE - 1)); + } + } + + // move the gmme desc by N buffers. + // Something nice to have if we have persistent kernels. + inline __device__ void increment_N_gmma_desc_group(int N) { +#pragma unroll + for (int idx = 0; idx < Smem_tile_a::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) { + uint64_t temp_desc = gmma_desc_a_[mma_m_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + tmp.x = (tmp.x & 0xFFFF0000) + (a_smem_base_ / 16) + + mma_m_idx * Smem_tile_a::GMMA_GROUP_SMEM_DISTANCE / 16 + + N * Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + gmma_desc_a_[mma_m_idx].set_descriptor(idx, temp_desc); + } + +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + tmp.x = + (tmp.x & 0xFFFF0000) + (b_smem_base_ / 16) + N * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // Clear the accumulators. It does nothing as we have a special flag for GMMA. + inline __device__ void clear() { fmha::clear(acc_); } + + // smarter way of increment a group of gmma desc. + // if one of them need to be reset to the first ldgsts buffer + // it is very likely (currently guaranteed) that all of them need to be reset to the first + // ldgsts buffer. + // we do this to save the usage of uniform register. Otherwise, kernel with larger M could not + // achieve sol. + inline __device__ void increment_gmma_desc_group() { + bool reset_buffer_a = + gmma_desc_a_[0].get_descriptor(0) >= gmma_desc_a_[0].get_max_descriptor_0(); + bool reset_buffer_b = + gmma_desc_b_[0].get_descriptor(0) >= gmma_desc_b_[0].get_max_descriptor_0(); + +#pragma unroll + for (int idx = 0; idx < Smem_tile_a::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) { + uint64_t temp_desc = gmma_desc_a_[mma_m_idx].get_descriptor(idx); + // smem start address is in lower 32bits + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer_a) { + tmp.x -= (Smem_tile_a::BUFFERS_PER_TILE - 1) * Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + } + + gmma_desc_a_[mma_m_idx].set_descriptor(idx, temp_desc); + } + +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer_b) { + tmp.x -= (Smem_tile_b::BUFFERS_PER_TILE - 1) * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // smarter way of increment a group of gmma desc. + // if one of them need to be reset to the first ldgsts buffer + // it is very likely (currently guaranteed) that all of them need to be reset to the first + // ldgsts buffer. + // we do this to save the usage of uniform register. 
Otherwise, kernel with larger M could not + // achieve sol. + inline __device__ void increment_gmma_desc_a_group() { + bool reset_buffer = gmma_desc_a_[0].get_descriptor(0) >= gmma_desc_a_[0].get_max_descriptor_0(); + +#pragma unroll + for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) { + uint64_t temp_desc = gmma_desc_a_[mma_m_idx].get_descriptor(idx); + // smem start address is in lower 32bits + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer) { + tmp.x -= (Smem_tile_a::BUFFERS_PER_TILE - 1) * Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_a::BYTES_PER_BUFFER_NO_4LSB; + } + gmma_desc_a_[mma_m_idx].set_descriptor(idx, temp_desc); + } + } + } + + // smarter way of increment a group of gmma desc. + // if one of them need to be reset to the first ldgsts buffer + // it is very likely (currently guaranteed) that all of them need to be reset to the first + // ldgsts buffer. + // we do this to save the usage of uniform register. Otherwise, kernel with larger M could not + // achieve sol. + template + inline __device__ void increment_gmma_desc_b_group(int N = 1) { + bool reset_buffer = + RESET_CHECK && gmma_desc_b_[0].get_descriptor(0) >= gmma_desc_b_[0].get_max_descriptor_0(); + +#pragma unroll + for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer) { + tmp.x -= (Smem_tile_b::BUFFERS_PER_TILE - 1) * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // Compute. + // last of group indicates it is the last GMMA with a GMMA group. So the GSB should be updated + // last of kblock indicates it is the last GMMA with kblock. so desc will be updated accordingly + inline __device__ void compute(int ki, bool last_of_group = false, bool last_of_kblock = false) { +#pragma unroll + for (int mmas_m_idx = 0; mmas_m_idx < MMAS_M; ++mmas_m_idx) { +#pragma unroll + for (int mmas_n_idx = 0; mmas_n_idx < MMAS_N; ++mmas_n_idx) { + // weird code to use SEL to avoid reg spill + typename Smem_tile_a::Gmma_descriptor::Single_desc single_desc_a; + typename Smem_tile_b::Gmma_descriptor::Single_desc single_desc_b; + + single_desc_a.set(gmma_desc_a_[mmas_m_idx].get_descriptor(ki)); + single_desc_b.set(gmma_desc_b_[mmas_n_idx].get_descriptor(ki)); + + if (mmas_n_idx == (MMAS_N - 1)) { + // update desc for A + gmma_desc_a_[mmas_m_idx].increment_single_descriptor(last_of_kblock); + } + if (mmas_m_idx == (MMAS_M - 1)) { + // update desc for B + gmma_desc_b_[mmas_n_idx].increment_single_descriptor(last_of_kblock); + } + + if ((last_of_group == true) && (mmas_m_idx == (MMAS_M - 1)) && + (mmas_n_idx == (MMAS_N - 1))) { + // increment the scoreboard + acc_[mmas_m_idx][mmas_n_idx].template mma(single_desc_a, single_desc_b); + } else { + acc_[mmas_m_idx][mmas_n_idx].template mma(single_desc_a, single_desc_b); + } + } // for (mmas_n_idx) + } // for (mmas_m_idx) + } + + // Load from shared memory. For GMMA where both operand comes from SMEM, this does nothing + inline __device__ void load(Smem_tile_a& smem_a, Smem_tile_b& smem_b, int ki, + bool first = false) {} + + // The accumulators. 
+ Fragment_accumulator acc_[MMAS_M][MMAS_N]; + + // one descriptor group per stage, different GMMAs may or maynot share descriptor group + // each descriptor group holds all the descriptors for the entire kblock + + // The descriptor to load A. + typename Smem_tile_a::Gmma_descriptor gmma_desc_a_[MMAS_M]; + // The descriptor to load B. + typename Smem_tile_b::Gmma_descriptor gmma_desc_b_[MMAS_N]; + uint32_t a_smem_base_, b_smem_base_; +}; + +/* +compute tile used when A is from RF, B is from SMEM +*/ +template +struct Compute_tile_with_gmma { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The fragment for holding A. + using Fragment = Fragment_a; + + // static_assert(Cta_tile::K == 128); + // static_assert(Mma_tile::K_PER_MMA_PER_CTA == 64 ); + // pstatic_assert(NUM_KBLOCKS == 384 / 64); + static constexpr int NUM_KBLOCKS = Smem_tile_b::BUFFERS_PER_TILE / Cta_tile::WARPS_K; + // static_assert(NUM_KBLOCKS * Cta_tile::WARPS_K == Smem_tile_b::BUFFERS_PER_TILE); + + // desc for A and B should have the same strategy + static_assert(Smem_tile_a::Gmma_descriptor::GMMA_DESC_SIZE_PER_GROUP == + Smem_tile_b::Gmma_descriptor::GMMA_DESC_SIZE_PER_GROUP, + "GMMA desc for A and B should have the same strategy."); + + // The number of MMAs. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // TODO + enum { MMAS_K = Mma_tile::MMAS_K * Cta_tile::WARPS_K }; + + // Ctor. + inline __device__ Compute_tile_with_gmma() {} + + // Ctor, that helps set the gmma descs + inline __device__ Compute_tile_with_gmma(void* a_smem_, void* b_smem_) + : Compute_tile_with_gmma(__nvvm_get_smem_pointer(a_smem_), __nvvm_get_smem_pointer(b_smem_)) { + } + + inline __device__ Compute_tile_with_gmma(uint32_t, uint32_t b_smem_base) + : b_smem_base_(b_smem_base) { + // We always start at buffer 0 and take the number of buffers from the Smem_tile, as above. + uint32_t b_smem = b_smem_base_; +// do not need to set desc for matrix A +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + gmma_desc_b_[mma_n_idx].set_smem_pointer(b_smem + + mma_n_idx * Smem_tile_b::GMMA_GROUP_SMEM_DISTANCE); + gmma_desc_b_[mma_n_idx].set_max_descriptor_0(Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB * + (Smem_tile_b::BUFFERS_PER_TILE - 1)); + } + } + + // move the gmme desc by N buffers. + // Something nice to have if we have persistent kernels. + inline __device__ void increment_N_gmma_desc_group(int N) { +#pragma unroll + for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + tmp.x = + (tmp.x & 0xFFFF0000) + (b_smem_base_ / 16) + (N)*Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // Clear the accumulators. It does nothing as we have a special flag for GMMA. + inline __device__ void clear() { fmha::clear(acc_); } + + // smarter way of increment a group of gmma desc. + // if one of them need to be reset to the first ldgsts buffer + // it is very likely (currently guaranteed) that all of them need to be reset to the first + // ldgsts buffer. + // we do this to save the usage of uniform register. Otherwise, kernel with larger M could not + // achieve sol. 
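+  // Sketch (hypothetical 4-buffer tile): the descriptor walks buffers
+  // 0 -> 1 -> 2 -> 3 by adding BYTES_PER_BUFFER_NO_4LSB each step; once it
+  // reaches the last buffer, the next update subtracts 3x that stride to wrap
+  // back to buffer 0, so a single comparison covers the whole group.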
+ + template + inline __device__ void increment_gmma_desc_group(int N = 1) { + bool reset_buffer = + RESET_CHECK && gmma_desc_b_[0].get_descriptor(0) >= gmma_desc_b_[0].get_max_descriptor_0(); + +#pragma unroll + for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { +#pragma unroll + for (int mma_n_idx = 0; mma_n_idx < MMAS_N; ++mma_n_idx) { + uint64_t temp_desc = gmma_desc_b_[mma_n_idx].get_descriptor(idx); + int2& tmp = reinterpret_cast(temp_desc); + if (reset_buffer) { + tmp.x -= (Smem_tile_b::BUFFERS_PER_TILE - 1) * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } else { + tmp.x += Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB; + } + gmma_desc_b_[mma_n_idx].set_descriptor(idx, temp_desc); + } + } + } + + // Compute. + // last of group indicates it is the last GMMA with a GMMA group. So the GSB should be updated + // last of kblock indicates it is the last GMMA with kblock. so desc will be updated accordingly + inline __device__ void compute(int ki, bool last_of_group = false, bool last_of_kblock = false) { +#pragma unroll + for (int mmas_m_idx = 0; mmas_m_idx < MMAS_M; ++mmas_m_idx) { +#pragma unroll + for (int mmas_n_idx = 0; mmas_n_idx < MMAS_N; ++mmas_n_idx) { + // weird code to use SEL to avoid reg spill + typename Smem_tile_b::Gmma_descriptor::Single_desc single_desc_b; + + single_desc_b.set(gmma_desc_b_[mmas_n_idx].get_descriptor(ki)); + + if (mmas_m_idx == (MMAS_M - 1)) { + // update desc for B + gmma_desc_b_[mmas_n_idx].increment_single_descriptor(last_of_kblock); + } + + if ((last_of_group == true) && (mmas_m_idx == (MMAS_M - 1)) && + (mmas_n_idx == (MMAS_N - 1))) { + // increment the scoreboard + acc_[mmas_m_idx][mmas_n_idx].template mma(a_[mmas_m_idx], single_desc_b); + } else { + acc_[mmas_m_idx][mmas_n_idx].template mma(a_[mmas_m_idx], single_desc_b); + } + } // for (mmas_n_idx) + } // for (mmas_m_idx) + } + + template + inline __device__ void compute_incta_splitk(Fragment const (&frag_a)[K][1], int const warp_k) { + if (Smem_tile_b::Gmma_descriptor::TRANS_MODE == Gmma_descriptor_transpose::NOTRANS) { + // In this case, the K dimension is the leading dimension, so we need to set the smem + // locations correctly for each Warp in K. + + // The number of elements in K per group. + constexpr int ELTS_PER_KGROUP = Smem_tile_b::BYTES_PER_ROW / sizeof(typename Traits::B_type); + // The number of MMAS to perform before incrementing by the group stride. + constexpr int MMAS_K_PER_GROUP = ELTS_PER_KGROUP / Traits::GMMA_K; + // The number of MMAS a k-warp performs. + constexpr int MMAS_K_PER_WARP = Mma_tile::MMAS_K; + + int const group_offset = warp_k * MMAS_K_PER_WARP; + // Initialize the descriptor + int gi = group_offset / MMAS_K_PER_GROUP; + int ii = group_offset % MMAS_K_PER_GROUP; + + int BYTES_OFFSET_NO_4LSB = gi * Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB + + ii * Smem_tile_b::Gmma_descriptor::BYTES_PER_DESC_NO_4LSB; + + uint64_t desc_b = gmma_desc_b_[0].get_descriptor(0); + int2& desc_b_view = reinterpret_cast(desc_b); + desc_b_view.x += BYTES_OFFSET_NO_4LSB; + + typename Smem_tile_b::Gmma_descriptor::Single_desc single_desc_b; + single_desc_b.set(desc_b); +#pragma unroll + for (int ki = 0; ki < MMAS_K_PER_WARP - 1; ki++) { + acc_[0][0].template mma(frag_a[ki][0], single_desc_b); + + // Increment the descriptor for the next kblock. + int const ki_next = group_offset + ki + 1; + // Update descriptor for next GMMA. 
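+        // At a k-group boundary the descriptor advances into the next buffer;
+        // otherwise it steps to the next descriptor within the current group.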
+ if (ki_next % MMAS_K_PER_GROUP == 0) { + desc_b_view.x += Smem_tile_b::BYTES_PER_BUFFER_NO_4LSB - + Smem_tile_b::Gmma_descriptor::BYTES_DESC_INC_BOUNDARY_NO_4LSB; + } else { + desc_b_view.x += Smem_tile_b::Gmma_descriptor::BYTES_PER_DESC_NO_4LSB; + } + single_desc_b.set(desc_b); + } + // Last one increments gsb. + acc_[0][0].template mma(frag_a[MMAS_K_PER_WARP - 1][0], single_desc_b); + } else { // GMMA supports transposed input: we can just advance SMEM address to the k-th block + // for each Warp in K. + + constexpr int NUM_KGROUPS = Smem_tile_b::BUFFERS_PER_TILE; + constexpr int MMAS_K_PER_GROUP = Mma_tile::MMAS_K / NUM_KGROUPS; + static_assert(MMAS_K_PER_GROUP * NUM_KGROUPS == Mma_tile::MMAS_K); + + uint64_t temp_desc = gmma_desc_b_[0].get_descriptor(0); + int2& tmp = reinterpret_cast(temp_desc); + + constexpr int BYTES_PER_K_GROUP_NO_4LSB = + Mma_tile::K_PER_WARP_GROUP * Mma_tile::N_PER_WARP_GROUP * sizeof(Traits::B_type) / 16; + tmp.x += warp_k * BYTES_PER_K_GROUP_NO_4LSB; + gmma_desc_b_[0].set_descriptor(0, temp_desc); + +#pragma unroll + for (int kbi = 0; kbi < NUM_KGROUPS - 1; kbi++) { +#pragma unroll + for (int ki = 0; ki < MMAS_K_PER_GROUP; ki++) { + fill_frag_a(frag_a[kbi * MMAS_K_PER_GROUP + ki][0]); + // Never increment scoreboard, but check for last kblock. + compute(ki, false, ki == MMAS_K_PER_GROUP - 1); + } + increment_gmma_desc_group(); + } + +#pragma unroll + for (int ki = 0; ki < MMAS_K_PER_GROUP - 1; ki++) { + fill_frag_a(frag_a[(NUM_KGROUPS - 1) * MMAS_K_PER_GROUP + ki][0]); + compute(ki); + } + + fill_frag_a(frag_a[NUM_KGROUPS * MMAS_K_PER_GROUP - 1][0]); + compute(NUM_KGROUPS * MMAS_K_PER_GROUP - 1, true, true); + } + } + + // Fill the input fragment + inline __device__ void fill_frag_a(Fragment a_temp) { +#pragma unroll + for (int idx = 0; idx < Fragment::NUM_REGS; ++idx) { + a_[0].reg(idx) = a_temp.reg(idx); + } + } + + // Load from shared memory. + // we don't actually need this with MHA fused kernel. + inline __device__ void load(Smem_tile_a& smem_a, Smem_tile_b& smem_b, int ki) { + // smem_a.load( a_[ki], ki ); + } + + // The accumulators. + Fragment_accumulator acc_[MMAS_M][MMAS_N]; + + // The fragments to load A. + // Need to think about is is better to declare as Fragment a_? + // for the second GEMM, MMAS_M is most likely 1. (at least for now. ) + Fragment a_[MMAS_M]; + + // one descriptor group per stage, different GMMAs may or maynot share descriptor group + // each descriptor group holds all the descriptors for the entire kblock + + // The descriptor to load B. + typename Smem_tile_b::Gmma_descriptor gmma_desc_b_[MMAS_N]; + uint32_t b_smem_base_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/fragment.h b/csrc/fmha_v2/fmha/hopper/fragment.h new file mode 100644 index 0000000000..0ee3c7e5be --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/fragment.h @@ -0,0 +1,491 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. 
Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// F R A G M E N T (A) +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Only needed if Operand A is coming from RF. +template +struct Fragment_a, Layout> + : public Fragment { + // A should be coming from RF. + static_assert(A_RF, "A_RF must be true to allocate RF for Operand A.\n"); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Only needed if Operand A is coming from RF. +template +struct Fragment_a, Layout> + : public Fragment { + // A should be coming from RF. + static_assert(A_RF, "A_RF must be true to allocate RF for Operand A.\n"); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Only needed if Operand A is coming from RF. +template +struct Fragment_a, Layout> + : public Fragment { + // A should be coming from RF. + static_assert(GMMA_A_RF == true, "GMMA_A_RF must be true to allocate RF for Operand A.\n"); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Only needed if Operand A is coming from RF. +template +struct Fragment_a, + Layout> + : public Fragment { + // A should be coming from RF. + static_assert(GMMA_A_RF == true, "GMMA_A_RF must be true to allocate RF for Operand A.\n"); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_a, + Layout> + // TODO: Do we need the * 4 or not? + : public Fragment { + static_assert(sizeof(Input_type_A) == 1); + static_assert(sizeof(Input_type_B) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H G M M A . F 1 6 +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// both operands are coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_REGS; ++ii) { + this->reg(ii) = hadd2(this->reg(ii), other.reg(ii)); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_fp16< + Gmma_single_desc_a::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(single_desc_a.get(), single_desc_b.get(), this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// both operands are coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. 
+ template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_REGS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_bf16< + Gmma_single_desc_a::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(single_desc_a.get(), single_desc_b.get(), this->regs_); + } +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// +// A is coming from RF; B is coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // The Traits + using Traits = Hopper_hgmma_fp16_traits; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_REGS; ++ii) { + this->reg(ii) = hadd2(this->reg(ii), other.reg(ii)); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_rfa_fp16< + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(a.regs_, single_desc_b.get(), this->regs_); + } +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// +// A is coming from RF; B is coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // The Traits + using Traits = Hopper_hgmma_bf16_traits; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_rfa_bf16< + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(a.regs_, single_desc_b.get(), this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H G M M A . F 3 2 +// +////////////////////////////////////////////////////////////////////////////////////////////////// +// both operands are coming from SMEM +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_fp32< + Gmma_single_desc_a::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? 
true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(single_desc_a.get(), single_desc_b.get(), this->regs_); + } +}; + +// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A is coming from RF; B is coming from SMEM +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // The Traits + using Traits = Hopper_hgmma_fp32_traits; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // Do the GMMA. + template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + fmha::hgmma_rfa_fp32< + Gmma_single_desc_b::TRANS_MODE == fmha::Gmma_descriptor_transpose::TRANS ? true : false, + GMMA_N, INCREMENT_SCORE_BOARD>(a.regs_, single_desc_b.get(), this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Q G M M A . F 3 2 +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// I G M M A . I N T 8 +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Both operands are coming from SMEM. +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + fmha::igmma_int8_int32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A is coming from RF; B is coming from SMEM + +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // The Traits. + using Traits = Hopper_igmma_int8_int32_traits; + + // Do the GMMA. + template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + fmha::igmma_rfa_int8_int32(a.regs_, single_desc_b.get(), + this->regs_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Fp32 Accumulator A operand from RF and B operand from SMEM +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Add two fragments. + template + inline __device__ void add(Other_fragment_ const& other) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { + this->elt(ii) = this->elt(ii) + other.elt(ii); + } + } + + // The Traits + using Traits = Hopper_qgmma_fp8_fp32_traits; + + // Do the GMMA. 
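+  // (The QGMMA flavour is chosen at compile time from the (A, B) fp8 element types: the pairs
+  //  e4m3 x e4m3, e5m2 x e4m3, e4m3 x e5m2 and e5m2 x e5m2 each dispatch to their own
+  //  qgmma_rfa_* wrapper below; any other pairing hits the assert.)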
+ template + inline __device__ void mma(Fragment_a const& a, + Gmma_single_desc_b const& single_desc_b) { + // call hgmma + if (std::is_same_v && std::is_same_v) { + qgmma_rfa_e4m3_e4m3_fp32(a.regs_, single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_rfa_e5m2_e4m3_fp32(a.regs_, single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_rfa_e4m3_e5m2_fp32(a.regs_, single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_rfa_e5m2_e5m2_fp32(a.regs_, single_desc_b.get(), + this->regs_); + } else { + assert(false && "unsupported"); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// fp32 accumulator +// Both operands are coming from SMEM. +template +struct Fragment_accumulator> + : public Fragment { + // The base class. + using Base = Fragment; + + // Do the GMMA. + template + inline __device__ void mma(Gmma_single_desc_a const& single_desc_a, + Gmma_single_desc_b const& single_desc_b) { + if (std::is_same_v && std::is_same_v) { + qgmma_e4m3_e4m3_fp32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_e5m2_e4m3_fp32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_e4m3_e5m2_fp32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } else if (std::is_same_v && std::is_same_v) { + qgmma_e5m2_e5m2_fp32(single_desc_a.get(), single_desc_b.get(), + this->regs_); + } else { + assert(false && "unsupported"); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_saver_tma { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Ctor. + template + inline __device__ Softmax_saver_tma(Params const& params, Head_info const& head_info) + : actual_len_(head_info.actual_seqlen), + local_q_tile_offset_(head_info.local_q_tile_offset), + softmax_sum_ptr_(reinterpret_cast(params.softmax_stats_ptr)), + softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes) { + softmax_max_ptr_ = reinterpret_cast(params.softmax_stats_ptr); + int warp = (threadIdx.x % 128) / Cta_tile::THREADS_PER_WARP; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + // MMA row0 index (8x4 thread layout) + row0_ = warp * Mma_tile::M_PER_MMA / WARPS_M + (lane / 4); + + int sum_s = + params.is_s_padded ? params.s * head_info.bidb : params.cu_q_seqlens[head_info.bidb]; + int token_id = sum_s * params.h + head_info.bidh; + size_t const bh_offset = + token_id * sizeof(float) * 2 + local_q_tile_offset_ * softmax_stats_stride_in_bytes_; + softmax_max_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_; + softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_ + sizeof(float); + }; + + inline __device__ void store(float* p_sum, float* p_max, float sqrt_d, int row_offset, + bool valid_run) { + // Four threads process two rows in mma, each row has one softmax_sum and one softmax_max. + // Here we use one thread to write one softmax element. + float values; + int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + if (lane % 4 < 2) { + values = p_sum[lane % 2]; + } else { + values = p_max[lane % 2] / sqrt_d; + } + if (!valid_run && (lane % 4) < 2) { + values = 1.0; + } + char* dst_ptr = (lane % 4 < 2) ? 
softmax_sum_ptr_ : softmax_max_ptr_; + size_t off_inside_mma = (lane % 2 == 0) ? row_offset : row_offset + 8; + if (local_q_tile_offset_ + row0_ + off_inside_mma < actual_len_) { + fmha::stg(dst_ptr + off_inside_mma * softmax_stats_stride_in_bytes_, values); + } + } + + // ptr + char* softmax_sum_ptr_ = nullptr; + char* softmax_max_ptr_ = nullptr; + + // the first row's idx + int row0_; + // actual seq length + int const actual_len_; + int const softmax_stats_stride_in_bytes_; + int const local_q_tile_offset_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h b/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h new file mode 100644 index 0000000000..7c9ac43bb8 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h @@ -0,0 +1,1138 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include + +namespace fmha { + +namespace v2 { + +template +struct Gmem_tile_o_hopper {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Not super proud of this. Need to refactor. +// A not optimized way of storing tile_O, without SMEM swizzle. +// STG.32 is going to be used. +template +struct Gmem_tile_o_hopper_16bits { + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 2 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = Mma_tile::M_PER_MMA / 8 / Cta_tile::WARPS_PER_CTA }; + + enum { ROWS = Cta_tile::M }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Mma_tile::N_PER_MMA / 4 / 2 }; + + // The number of valid columns (stored to GMEM) by each thread. + enum { + VALID_COLS_PER_THREAD_FOR_LAST_MMA = (Cta_tile::VALID_N % Mma_tile::N_PER_MMA) == 0 + ? COLS_PER_THREAD + : (Cta_tile::VALID_N % Mma_tile::N_PER_MMA) / 8 + }; + + enum { VALID_MMAS_N = fmha::Div_up::VALUE }; + + static_assert(Cta_tile::VALID_N % 8 == 0, "The valid head dimension needs to be multiple of 8."); + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELTS_PER_THREAD = ROWS_PER_THREAD * COLS_PER_THREAD }; + + // Currently, we assume for o matrix, GMMA M/N shape matches CTA M/N shape. + static_assert(Mma_tile::M_PER_MMA == Cta_tile::M && + Mma_tile::N_PER_MMA * Mma_tile::MMAS_N == Cta_tile::N, + "Currently, we assume for o matrix, GMMA M shape matches CTA M shape. "); + + // Step N for one quad + enum { STEP_N = 8 * BYTES_PER_ELEMENT }; + + // Ctor. 
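+  // (Illustrative mapping, assuming Mma_tile::M_PER_MMA == 64: each warp of the warpgroup owns a
+  //  16-row slice, a lane's first row is warp_in_warpgroup * 16 + lane / 4, and its first column is
+  //  (lane % 4) * 2 elements, so warp 1 / lane 5 starts at row 17, column 2. Each STG.32 then writes
+  //  one 2-element (4-byte) chunk per accumulator register.)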
+ template + inline __device__ Gmem_tile_o_hopper_16bits(Params const& params, Block_info const& block_info, + int tidx, int cta_row_offset = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + actual_seqlen_(block_info.actual_seqlen), + o_ptr_(reinterpret_cast(params.o_ptr)) { + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // int warpgroup_idx = warp / 4; + int warp_idx_within_warpgroup = warp % 4; + + // Compute the position in the sequence (within the CTA for the moment). + int row = warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) + lane / 4; + // Store the row to update the predicates in load. + row_ = cta_row_offset + row; + // Compute the position of the thread in the row. + int col = lane % 4 * ELEMENTS_PER_STG; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = + (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + // Finalize the pointer. + o_ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + int64_t const step_m = 8 * (this->params_o_stride_in_bytes_); + // we assume M = 1. some shortcuts. + static_assert(M == 1); +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + if (row_ + row_idx * 8 >= actual_seqlen_) { + break; + } +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N - 1; ++mma_ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { + uint32_t acc_0 = acc[0][mma_ni].reg(col_idx * ROWS_PER_THREAD + row_idx); + + int64_t offset = + (int64_t)row_idx * step_m + (int64_t)(col_idx + mma_ni * COLS_PER_THREAD) * STEP_N; + fmha::stg(o_ptr_ + offset, acc_0); + } // col_idx + } // mma_ni + + // The last mma_n may not store full elements back to GMEM. + int mma_ni = VALID_MMAS_N - 1; +#pragma unroll + for (int col_idx = 0; col_idx < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ++col_idx) { + uint32_t acc_0 = acc[0][mma_ni].reg(col_idx * ROWS_PER_THREAD + row_idx); + + int64_t offset = + (int64_t)row_idx * step_m + (int64_t)(col_idx + mma_ni * COLS_PER_THREAD) * STEP_N; + fmha::stg(o_ptr_ + offset, acc_0); + } // col_idx + } // row_idx + } + + // Move to the next location. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The pointer. + char* o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + + // The row loaded by this thread. + int row_; + // The length of the sequence loaded by that CTA. 
+ int actual_seqlen_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + 1> // WARPS_K + : public Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_fp16_traits, Cta_tile> { + using Traits = fmha::Hopper_hgmma_fp16_traits; + + using Base = Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_fp16_traits, Cta_tile>; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + 1> // WARPS_K + : public Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_fp32_traits, Cta_tile> { + using Traits = fmha::Hopper_hgmma_fp32_traits; + + using Base = Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_fp32_traits, Cta_tile>; + + using Mma_tile = typename Base::Mma_tile; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + int64_t const step_m = 8 * (this->params_o_stride_in_bytes_); + // we assume M = 1. some shortcuts. + static_assert(M == 1); +#pragma unroll + for (int row_idx = 0; row_idx < Base::ROWS_PER_THREAD; ++row_idx) { + if (this->row_ + row_idx * 8 >= this->actual_seqlen_) { + break; + } +#pragma unroll + for (int mma_ni = 0; mma_ni < Base::VALID_MMAS_N - 1; ++mma_ni) { +#pragma unroll + for (int col_idx = 0; col_idx < Base::COLS_PER_THREAD; ++col_idx) { + // 2 denotes as fp32 --> fp16 + float reg0 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx)); + float reg1 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx) + 1); + uint32_t out = fmha::float2_to_half2(reg0, reg1); + + int64_t offset = (int64_t)row_idx * step_m + + (int64_t)(col_idx + mma_ni * Base::COLS_PER_THREAD) * Base::STEP_N; + fmha::stg(this->o_ptr_ + offset, out); + } // col_idx + } // mma_ni + + // The last mma_n may not store full elements back to GMEM. 
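+      // (Worked example, assuming GMMA_N == 64 and a hypothetical head size VALID_N == 40:
+      //  COLS_PER_THREAD == 8, VALID_MMAS_N == 1 and VALID_COLS_PER_THREAD_FOR_LAST_MMA == 40 % 64 / 8
+      //  == 5, so the mma_ni loop above is skipped and only the first 5 two-element chunks of the
+      //  single MMA are written back per thread.)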
+ int mma_ni = Base::VALID_MMAS_N - 1; +#pragma unroll + for (int col_idx = 0; col_idx < Base::VALID_COLS_PER_THREAD_FOR_LAST_MMA; ++col_idx) { + // 2 denotes as fp32 --> fp16 + float reg0 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx)); + float reg1 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx) + 1); + uint32_t out = fmha::float2_to_half2(reg0, reg1); + + int64_t offset = (int64_t)row_idx * step_m + + (int64_t)(col_idx + mma_ni * Base::COLS_PER_THREAD) * Base::STEP_N; + fmha::stg(this->o_ptr_ + offset, out); + } // col_idx + } // row_idx + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + 1> // WARPS_K + : public Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_bf16_traits, Cta_tile> { + using Traits = fmha::Hopper_hgmma_bf16_traits; + + using Base = Gmem_tile_o_hopper_16bits< + fmha::Hopper_hgmma_bf16_traits, Cta_tile>; + + using Mma_tile = typename Base::Mma_tile; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + int64_t const step_m = 8 * (this->params_o_stride_in_bytes_); + // we assume M = 1. some shortcuts. + static_assert(M == 1); +#pragma unroll + for (int row_idx = 0; row_idx < Base::ROWS_PER_THREAD; ++row_idx) { + if (this->row_ + row_idx * 8 >= this->actual_seqlen_) { + break; + } +#pragma unroll + for (int mma_ni = 0; mma_ni < Mma_tile::VALID_MMAS_N - 1; ++mma_ni) { +#pragma unroll + for (int col_idx = 0; col_idx < Base::COLS_PER_THREAD; ++col_idx) { + // 2 denotes as fp32 --> bf16 + float reg0 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx)); + float reg1 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx) + 1); + uint32_t out = fmha::float2_to_bf16_x2(reg0, reg1); + + int64_t offset = (int64_t)row_idx * step_m + + (int64_t)(col_idx + mma_ni * Base::COLS_PER_THREAD) * Base::STEP_N; + fmha::stg(this->o_ptr_ + offset, out); + } // row_idx + } // col_idx + + // The last mma_n may not store full elements back to GMEM. 
+ int mma_ni = Base::VALID_MMAS_N - 1; +#pragma unroll + for (int col_idx = 0; col_idx < Base::VALID_COLS_PER_THREAD_FOR_LAST_MMA; ++col_idx) { + // 2 denotes as fp32 --> bf16 + float reg0 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx)); + float reg1 = acc[0][mma_ni].elt(2 * (col_idx * Base::ROWS_PER_THREAD + row_idx) + 1); + uint32_t out = fmha::float2_to_bf16_x2(reg0, reg1); + + int64_t offset = (int64_t)row_idx * step_m + + (int64_t)(col_idx + mma_ni * Base::COLS_PER_THREAD) * Base::STEP_N; + fmha::stg(this->o_ptr_ + offset, out); + } // row_idx + } // mma_ni + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + 2> // WARPS_K + : public fmha::v2::Hmma_gmem_tile_o< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + /*CTAS_PER_HEAD=*/1, + /*BYTES_PER_STG=*/16> { + using Traits = fmha::Hopper_hgmma_fp16_traits; + using Base = fmha::v2::Hmma_gmem_tile_o; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + 2> // WARPS_K + : public fmha::v2::Hmma_gmem_tile_o< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + /*CTAS_PER_HEAD=*/1, + /*BYTES_PER_STG=*/16> { + using Traits = fmha::Hopper_hgmma_fp32_traits; + using Base = fmha::v2::Hmma_gmem_tile_o; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + 2> // WARPS_K + : public fmha::v2::Hmma_gmem_tile_o< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + /*CTAS_PER_HEAD=*/1, + /*BYTES_PER_STG=*/16> { + using Traits = fmha::Hopper_hgmma_bf16_traits; + using Base = fmha::v2::Hmma_gmem_tile_o; + + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, Shared&&, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o, + Cta_tile, CTAS_PER_HEAD> + : public Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_hgmma_fp16_traits; + + using Base = Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, std::nullptr_t{} /* dummy obj */, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o, + Cta_tile, CTAS_PER_HEAD> + : public Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_hgmma_fp32_traits; + + using Base = Gmem_tile_o_hopper< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, std::nullptr_t{} /* dummy obj */, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o, + Cta_tile, CTAS_PER_HEAD> + : public Gmem_tile_o_hopper< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_hgmma_bf16_traits; + + using Base = Gmem_tile_o_hopper< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, std::nullptr_t{} /* dummy obj */, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_gmma_32bit_8bit { + static_assert(sizeof(typename Traits::Accumulator_type) == 4); + static_assert(sizeof(typename Traits::C_type) == 1); + // This is for non-splitk GMMA BMM2. + static_assert(Cta_tile::WARPS_K == 1); + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 4 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = 1 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + enum { ROWS = Cta_tile::M }; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = Mma_tile::M_PER_MMA / 8 / Cta_tile::WARPS_M }; + + static_assert(ROWS_PER_THREAD == 2); + static_assert(ROWS_PER_THREAD == Mma_tile::ROWS_PER_THREAD); + + // The number of columns access by each thread. + // The number of core matrices in N. + enum { COLS_PER_THREAD = Mma_tile::N_PER_MMA / 4 / 2 }; // N_PER_MMA = GMMA_N + + static_assert(COLS_PER_THREAD == Mma_tile::COLS_PER_THREAD / 2); + // Assume there is an even number of core matrices, such that we can pack two + static_assert(COLS_PER_THREAD % 2 == 0); + + // Number of valid N columns. + enum { VALID_N = Cta_tile::VALID_N }; + + // The number of valid columns (stored to GMEM) by each thread. + enum { + VALID_COLS_PER_THREAD_FOR_LAST_MMA = + (VALID_N % Mma_tile::N_PER_MMA) == 0 ? COLS_PER_THREAD : (VALID_N % Mma_tile::N_PER_MMA) / 8 + }; + + enum { VALID_MMAS_N = fmha::Div_up::VALUE }; + + static_assert(VALID_N % 8 == 0, "The valid head dimension needs to be multiple of 8."); + + // The number of N elements must be multiple of 16 in order to pack 4 elements as uint32_t. 
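+  // (E.g. a hypothetical VALID_N of 72 is a multiple of 8 but not of 16, so PACK_4_ELTS below is
+  //  false and the tail MMA falls back to the two-element STG path further down in store().)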
+ enum { PACK_4_ELTS = VALID_N % 16 == 0 }; + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELTS_PER_THREAD = ROWS_PER_THREAD * COLS_PER_THREAD * 2 }; + + // Currently, we assume for o matrix, GMMA M shape matches CTA M shape. + static_assert(Mma_tile::M_PER_MMA == Cta_tile::M && + Mma_tile::N_PER_MMA * Mma_tile::MMAS_N == Cta_tile::N, + "Currently, we assume for o matrix, GMMA M/N shape matches CTA M/N shape. "); + + // Step N for one quad (pack 4 elements for a thread, so 16 elements for a quad) + enum { STEP_N = 16 * BYTES_PER_ELEMENT }; + + // The number of head_dimension groups. + enum { N_GROUPS = fmha::Div_up::VALUE }; + + // The head_dimension per group. + enum { N_PER_GROUP = Cta_tile::N / N_GROUPS }; + + static_assert(N_GROUPS * N_PER_GROUP == Cta_tile::N); + + // The head_dimension bytes per group + enum { N_BYTES_PER_GROUP = Cta_tile::N * BYTES_PER_ELEMENT / N_GROUPS }; + + // Pack 2x4 core matrices, use STSMx4 + enum { STSM_PER_MMA = COLS_PER_THREAD / 4 }; + + // The number of registers per 16x16 block + enum { REGS_PER_QUAD = 8 }; + + // Bytes per bank + enum { BYTES_PER_BANK = 16 }; + + // The number of banks in N per group + enum { N_BANKS_PER_GROUP = N_BYTES_PER_GROUP / BYTES_PER_BANK }; + + enum { USE_TMA_STORE = USE_TMA_STORE_ }; + + // Ctor. + template + inline __device__ Gmem_tile_o_gmma_32bit_8bit(Params const& params, Block_info const& block_info, + Shared& shared, int tidx, int cta_row_offset = 0) + : Gmem_tile_o_gmma_32bit_8bit( + params.o_ptr, params.o_stride_in_bytes, block_info, tidx, +#ifdef GENERATE_CUBIN + // Specialized for trt-llm generated cubins only. + params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2, +#else + params.scale_bmm2, +#endif + cta_row_offset, 0, + __nvvm_get_smem_pointer(reinterpret_cast( + &shared.smem_o[__shfl_sync(0xffffffff, threadIdx.x / 128, 0)][0])), + ¶ms.tma_desc_o, params.h) { + } + + template + inline __device__ Gmem_tile_o_gmma_32bit_8bit(void* o_ptr, int o_stride_in_bytes, + Block_info const& block_info, int tidx, + uint32_t scale_bmm2, int cta_row_offset = 0, + int mat_offset = 0, uint32_t smem_base = 0, + cudaTmaDesc const* desc_o = nullptr, + int head_num = 0) + : params_o_stride_in_bytes_(o_stride_in_bytes), + actual_seqlen_(block_info.actual_seqlen), + o_ptr_(reinterpret_cast(o_ptr)), + params_scale_bmm2_(scale_bmm2), + smem_base_(smem_base), + desc_o_(desc_o) { + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // int warpgroup_idx = warp / 4; + int warp_idx_within_warpgroup = warp % 4; + + if (USE_TMA_STORE) { + // The head index + bidh_ = block_info.bidh; + // The lane id + lane_ = lane; + // The start row index for current batch + int row_curr_batch = (block_info.bidx - block_info.bidh) / head_num; + // The row index offset of current warp + int row_offset_warp = cta_row_offset + warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4); + // The row index for the current warp + row_tma_ = row_offset_warp + row_curr_batch; + // The valid rows for the current warp. Each warp writes from 0 to 16 rows + num_valid_rows_ = min(Mma_tile::M_PER_MMA / 4, actual_seqlen_ - row_offset_warp); + num_valid_rows_ = max(num_valid_rows_, 0); + // WARNING: Without this line, the predicate will not behavior as expected for unknown reason. 
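+      // (__shfl_sync with source lane 0 broadcasts lane 0's copy of num_valid_rows_ to every lane of
+      //  the warp, forcing the value to be warp-uniform before it gates the TMA store path below.)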
+ num_valid_rows_ = __shfl_sync(0xffffffff, num_valid_rows_, 0); + // Compute the smem base for STSM + smem_base_ += + warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) * Cta_tile::N * BYTES_PER_ELEMENT + + (warp / 4) * Mma_tile::M_PER_MMA * Cta_tile::N * BYTES_PER_ELEMENT; + // Compute gmem base for STG in tail case + o_ptr_ += row_tma_ * params_o_stride_in_bytes_ + bidh_ * BYTES_PER_ROW; + } else { + // Compute the position in the sequence (within the CTA for the moment). + int row = warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) + lane / 4; + // Store the row to update the predicates in load. + row_ = cta_row_offset + row; + // Compute the position of the thread in the row. + col_ = lane % 4 * ELEMENTS_PER_STG; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = + (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + // Finalize the pointer. + o_ptr_ += row_offset + col_ * BYTES_PER_ELEMENT; + } + + // REVIEW: need heads_interleaved option for non-warp-specialized QGMMA + LDGSTS kernels. + // // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + // int64_t row_offset = (int64_t) row_ * params_o_stride_in_bytes_; + // // Add the block index. + // int64_t idx = block_info.bidx; + // if(NUM_MATS > 1) { + // if( HEADS_INTERLEAVED ) { + // idx = block_info.bidx * NUM_MATS + mat_offset; + // } else { + // idx = (block_info.sum_s * NUM_MATS + mat_offset) * block_info.num_heads + + // block_info.bidh; + // } + // } + // // Assemble the final pointer. + // o_ptr_ += row_offset + idx * BYTES_PER_ROW + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + static_assert(Accumulators::NUM_ELTS == ELTS_PER_THREAD); + static_assert(COLS_PER_THREAD / 2 * ROWS_PER_THREAD * 4 == ELTS_PER_THREAD); + + // we assume M = N = 1. some shortcuts. + static_assert(M == 1); + + if (USE_TMA_STORE) { + static_assert(COLS_PER_THREAD % 4 == 0); + static_assert(ROWS_PER_THREAD == 2); + + int const swizzled_row = (lane_ % 16); + int const swizzled_col = (lane_ / 16); + constexpr int max_swizzle_id = N_BYTES_PER_GROUP / 16; + constexpr int swizzle_row_divider = 128 / N_BYTES_PER_GROUP; + + uint32_t stsm_addr[VALID_MMAS_N][STSM_PER_MMA]; +// Compute swizzled smem address +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N; ++mma_ni) { +#pragma unroll + for (int ci = 0; ci < STSM_PER_MMA; ++ci) { + int const col_bank = ((mma_ni)*STSM_PER_MMA + ci) * 2 + swizzled_col; + int const di = col_bank / N_BANKS_PER_GROUP; // which N group it belongs to + stsm_addr[mma_ni][ci] = smem_base_ + di * 16 * N_BYTES_PER_GROUP + // group dimension + (((swizzled_row / swizzle_row_divider) % max_swizzle_id) ^ + (col_bank % N_BANKS_PER_GROUP)) * + BYTES_PER_BANK + // column dimension + swizzled_row * N_BYTES_PER_GROUP; // row dimension + } + } + +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N; ++mma_ni) { +#pragma unroll + for (int ci = 0; ci < STSM_PER_MMA; ++ci) { + uint32_t dst[4]; + uint4 src[4]; + + /* + * Each STSMx4 produces a 16x32 block, that is 2x4 core matrices + * ----------------- + * | 0 | 2 | 4 | 6 | + * ----------------- + * | 1 | 3 | 5 | 7 | + * ----------------- + * + * Consider the entire warp, src[0] holds matrices 0,2; src[1] holds matrices 1,3; + * src[3] holds matrices 4,6; src[4] holds matrices 5,7. 
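+           * (Index note: the src array has four entries, src[0..3]; the third and fourth operands
+           *  are src[2] -> matrices {4, 6} and src[3] -> matrices {5, 7}, matching the uint4 src[4]
+           *  fills below.)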
+ */ + src[0].x = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 0); + src[0].y = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 4); + src[0].z = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 1); + src[0].w = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 5); + + src[1].x = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 2); + src[1].y = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 6); + src[1].z = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 3); + src[1].w = acc[0][mma_ni].reg((ci * 2 + 0) * REGS_PER_QUAD + 7); + + src[2].x = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 0); + src[2].y = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 4); + src[2].z = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 1); + src[2].w = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 5); + + src[3].x = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 2); + src[3].y = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 6); + src[3].z = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 3); + src[3].w = acc[0][mma_ni].reg((ci * 2 + 1) * REGS_PER_QUAD + 7); + + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; +// Packs the 32bit values to 8bit. +// Depending on the type, applies extra scaling with parameter scale_bmm2. +#pragma unroll + for (int i = 0; i < 4; ++i) { +#ifdef UNIFIED_EPILOGUE_SCALE + dst[i] = Acc_packer::run(this, src[i]); +#else + dst[i] = Acc_packer::run(this, src[i]); +#endif + } + stsm(stsm_addr[mma_ni][ci], *reinterpret_cast(&dst[0])); + } + } + + // TODO: Interleave STSM and UTMASTG of two N groups + constexpr int MAX_ROWS_PER_WARP = Mma_tile::M_PER_MMA / 4; + if (num_valid_rows_ == MAX_ROWS_PER_WARP) { + fence_view_async_shared(); +#pragma unroll + for (int di = 0; di < N_GROUPS; ++di) { + const int32_t coords[3] = {di * N_PER_GROUP, bidh_, row_tma_}; + fmha::utmastg<3, fmha::cudaTmaDescType::TILED>( + desc_o_, smem_base_ + di * 16 * N_BYTES_PER_GROUP, coords); + } + tmastg_arrive(); + tmastg_wait(); + } else if (num_valid_rows_ > 0) { + // Use LDS.64 + STG.64 to store num_valid_rows_ x N tile + constexpr int BYTES_PER_THREAD = 8; + static_assert((VALID_N % BYTES_PER_THREAD) == 0, "VALID_N must be divided by 8 for STG.64"); + // Number of valid rows + int row_size = num_valid_rows_; + // Number of threads per row. Each thread read/write 8B (8 elements). 
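+      // (For instance, with N_BYTES_PER_GROUP == 128 this gives THREADS_PER_ROW == 16 and
+      //  ROWS_PER_WARP == 2, so each pass of the ri loop below lets one warp copy two rows of the
+      //  group from SMEM back to GMEM using LDS.64 + STG.64.)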
+ constexpr int THREADS_PER_ROW = N_BYTES_PER_GROUP / 8; + // Number of rows read/written by a warp + static_assert(Cta_tile::THREADS_PER_WARP % THREADS_PER_ROW == 0, + "A warp must reads full rows"); + constexpr int ROWS_PER_WARP = Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW; + // GMEM stride in M dimension + int64_t const step_m = (this->params_o_stride_in_bytes_); + // Initial column index + int const ci = lane_ % THREADS_PER_ROW; + int const bank_idx = (ci * BYTES_PER_THREAD) / BYTES_PER_BANK; + int const bank_offset = (ci * BYTES_PER_THREAD) % BYTES_PER_BANK; + +#pragma unroll + for (int di = 0; di < N_GROUPS; ++di) { + // Detect GMEM index out of bound + if ((di * N_BYTES_PER_GROUP + ci * BYTES_PER_THREAD) >= BYTES_PER_ROW) { + break; + } +#pragma unroll + for (int ri = lane_ / THREADS_PER_ROW; ri < row_size; ri += ROWS_PER_WARP) { + // Create the swizzled offset + uint32_t smem_offset = + di * 16 * N_BYTES_PER_GROUP + ri * N_BYTES_PER_GROUP + + (((ri / swizzle_row_divider) % max_swizzle_id) ^ bank_idx) * BYTES_PER_BANK + + bank_offset; + uint2 buffer; + fmha::lds(buffer, smem_base_ + smem_offset); + int64_t gmem_offset = + (int64_t)ri * step_m + di * N_BYTES_PER_GROUP + ci * BYTES_PER_THREAD; + fmha::stg(o_ptr_ + gmem_offset, buffer); + } + } + } + } else { + int64_t const step_m = 8 * (this->params_o_stride_in_bytes_); + +#pragma unroll + for (int ri = 0; ri < ROWS_PER_THREAD; ++ri) { + if (row_ + ri * 8 >= actual_seqlen_) { + break; + } + +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N - 1; ++mma_ni) { +// Iterate over 16 columns to pack 4 values per thread. +#pragma unroll + for (int ci = 0; ci < COLS_PER_THREAD / 2; ++ci) { + // Assuming EVEN,EVEN,ODD,ODD column pattern due to packing of V. + uint4 src; + src.x = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 0); // 0 + src.y = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 0); // 4 + src.z = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1); // 1 + src.w = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1); // 5 + + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; + // Packs the 32bit values to 8bit. + // Depending on the type, applies extra scaling with parameter scale_bmm2. +#ifdef UNIFIED_EPILOGUE_SCALE + uint32_t dst = Acc_packer::run(this, src); +#else + uint32_t dst = Acc_packer::run(this, src); +#endif + + int64_t offset = + (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD / 2) * STEP_N; + fmha::stg(o_ptr_ + offset, dst); + } // ci + } // mma_ni + + if constexpr (PACK_4_ELTS) { + // The last mma_n may not store full elements back to GMEM. + int mma_ni = VALID_MMAS_N - 1; +// Iterate over 16 columns to pack 4 values per thread. +#pragma unroll + for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA / 2; ++ci) { + // Assuming EVEN,EVEN,ODD,ODD column pattern due to packing of V. + uint4 src; + src.x = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 0); // 0 + src.y = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 0); // 4 + src.z = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1); // 1 + src.w = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1); // 5 + + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; + // Packs the 32bit values to 8bit. + // Depending on the type, applies extra scaling with parameter scale_bmm2. 
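+          // (With UNIFIED_EPILOGUE_SCALE defined, the packer is instantiated without the extra scale
+          //  step, the scale being expected to have been applied elsewhere in the epilogue; otherwise
+          //  it multiplies by params_scale_bmm2_ while converting fp32 to fp8. The same switch appears
+          //  as the explicit Scale flag in Gmem_tile_o_qgmma_fp32_16bits further down.)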
+#ifdef UNIFIED_EPILOGUE_SCALE + uint32_t dst = Acc_packer::run(this, src); +#else + uint32_t dst = Acc_packer::run(this, src); +#endif + + int64_t offset = + (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD / 2) * STEP_N; + fmha::stg(o_ptr_ + offset, dst); + } // ci + } else { + // The last mma_n may not store full elements back to GMEM. + int mma_ni = VALID_MMAS_N - 1; +// Iterate over 16 columns to pack 4 values per thread (2 uint2). +#pragma unroll + for (int ci = 0; ci < fmha::Div_up::VALUE; ++ci) { + // Assuming EVEN,EVEN,ODD,ODD column pattern due to packing of V. + uint2 src0, src1; + src0.x = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 0); // 0 + src0.y = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 0); // 4 + src1.x = acc[0][mma_ni].reg(((2 * ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1); // 1 + src1.y = acc[0][mma_ni].reg(((2 * ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1); // 5 + + using Src_type = typename Traits::Accumulator_type; + using Dst_type = typename Traits::C_type; +#ifdef UNIFIED_EPILOGUE_SCALE + uint16_t dst0 = Acc_packer::run(this, src0); + uint16_t dst1 = Acc_packer::run(this, src1); +#else + uint16_t dst0 = Acc_packer::run(this, src0); + uint16_t dst1 = Acc_packer::run(this, src1); +#endif + + // 4 elements per thread, so 16 elements per loop. + int col_idx = (ci + mma_ni * COLS_PER_THREAD / 2) * 16; + + int64_t offset = (int64_t)ri * step_m + (int64_t)(col_idx)*BYTES_PER_ELEMENT; + + if (col_idx + col_ < VALID_N) { + fmha::stg(o_ptr_ + offset, dst0); + } + + if (col_idx + col_ + 2 < VALID_N) { + fmha::stg(o_ptr_ + offset + 2 * BYTES_PER_ELEMENT, dst1); + } + } // ci + } + } // ri + } + } + + // Move to the next location. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // The pointer. + char* o_ptr_; + // Is the thread active for the last STG? + int is_active_for_last_stg_; + + // The row, col loaded by this thread. + int row_, col_; + // The length of the sequence loaded by that CTA. + int actual_seqlen_; + + // Scaling factor; this usually means QKV descale factor in actuality + uint32_t params_scale_bmm2_; + + // Smem buffer for TMASTG + uint32_t smem_base_; + cudaTmaDesc const* desc_o_; + + int lane_; + int row_tma_; + int num_valid_rows_; + int bidh_; + + bool const params_enable_i2f_trick_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper_32bit_8bit {}; + +template +struct Gmem_tile_o_hopper_32bit_8bit + : public Gmem_tile_o_gmma_32bit_8bit { + // The Base class. + using Base = Gmem_tile_o_gmma_32bit_8bit; + + // Ctor. + template + inline __device__ Gmem_tile_o_hopper_32bit_8bit(Params const& params, + Block_info const& block_info, Shared& shared, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, shared, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper_32bit_8bit + : public Gmem_tile_o_8bit { + // The Base class. + using Base = Gmem_tile_o_8bit; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o_hopper_32bit_8bit(Params const& params, + Block_info const& block_info, Shared& shared, + int tidx, int cta_row_offset = 0) + : Base(params, block_info, shared, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_hopper< + fmha::Hopper_qgmma_fp8_fp32_traits, Cta_tile, + CTAS_PER_HEAD> + : public Gmem_tile_o_hopper_32bit_8bit< + fmha::Hopper_qgmma_fp8_fp32_traits, + Cta_tile, Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_qgmma_fp8_fp32_traits; + + using Base = Gmem_tile_o_hopper_32bit_8bit< + fmha::Hopper_qgmma_fp8_fp32_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. + template + inline __device__ Gmem_tile_o_hopper(Params const& params, Block_info const& block_info, + Shared& shared, int tidx, int cta_row_offset = 0) + : Base(params, block_info, shared, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o< + fmha::Hopper_igmma_int8_int32_traits, Cta_tile, + CTAS_PER_HEAD> + : public Gmem_tile_o_hopper_32bit_8bit< + fmha::Hopper_igmma_int8_int32_traits, + Cta_tile, Cta_tile::WARPS_K> { + // The traits class. + using Traits = fmha::Hopper_igmma_int8_int32_traits; + + using Base = Gmem_tile_o_hopper_32bit_8bit< + fmha::Hopper_igmma_int8_int32_traits, Cta_tile, + Cta_tile::WARPS_K>; + + // Ctor. + template + inline __device__ Gmem_tile_o(Params const& params, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + : Base(params, block_info, std::nullptr_t{} /* dummy obj */, tidx, cta_row_offset) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_qgmma_fp32_16bits { + // The associated MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of elements per STG. + enum { ELEMENTS_PER_STG = 2 }; + + // The size in bytes of each element. + enum { BYTES_PER_ELEMENT = 2 }; + + // The size of each STG. + enum { BYTES_PER_STG = ELEMENTS_PER_STG * BYTES_PER_ELEMENT }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = Cta_tile::VALID_N * BYTES_PER_ELEMENT }; + + // The number of rows accessed by each thread. + enum { ROWS_PER_THREAD = Mma_tile::M_PER_MMA / 8 / Cta_tile::WARPS_PER_CTA }; + + enum { ROWS = Cta_tile::M }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Mma_tile::N_PER_MMA / 4 / 2 }; + + // The number of valid columns (stored to GMEM) by each thread. + enum { + VALID_COLS_PER_THREAD_FOR_LAST_MMA = (Cta_tile::VALID_N % Mma_tile::N_PER_MMA) == 0 + ? COLS_PER_THREAD + : (Cta_tile::VALID_N % Mma_tile::N_PER_MMA) / 8 + }; + + enum { VALID_MMAS_N = fmha::Div_up::VALUE }; + + static_assert(Cta_tile::VALID_N % 8 == 0, "The valid head dimension needs to be multiple of 8."); + + // The number of accumulator held by each thread, per HGMMA instruction. + enum { ELTS_PER_THREAD = ROWS_PER_THREAD * COLS_PER_THREAD }; + + // Currently, we assume for o matrix, GMMA M/N shape matches CTA M/N shape. + static_assert(Mma_tile::M_PER_MMA == Cta_tile::M && + Mma_tile::N_PER_MMA * Mma_tile::MMAS_N == Cta_tile::N, + "Currently, we assume for o matrix, GMMA M shape matches CTA M shape. "); + + // Step N for one quad + enum { STEP_N = 8 * BYTES_PER_ELEMENT }; + + // Ctor. 
+ template + inline __device__ Gmem_tile_o_qgmma_fp32_16bits(Params const& params, + Block_info const& block_info, Shared&&, int tidx, + int cta_row_offset = 0) + : params_o_stride_in_bytes_(params.o_stride_in_bytes), + params_scale_bmm2_( +#ifdef GENERATE_CUBIN + // Specialized for trt-llm generated cubins only. + params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2 +#else + params.scale_bmm2 +#endif + ), + actual_seqlen_(block_info.actual_seqlen), + o_ptr_(reinterpret_cast(params.o_ptr)) { + static_assert(!std::is_same::value, "Check constructor argument type!"); + // Decompose the position of the thread into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + int warp_idx_within_warpgroup = warp % 4; + + // Compute the position in the sequence (within the CTA for the moment). + int row = warp_idx_within_warpgroup * (Mma_tile::M_PER_MMA / 4) + lane / 4; + // Store the row to update the predicates in load. + row_ = cta_row_offset + row; + // Compute the position of the thread in the row. + // echo loop handles 2 cores, so x2 (this is the difference to Gmem_tile_o_hopper_16bits) + int col = lane % 4 * ELEMENTS_PER_STG * 2; + + // The offset of the 1st row written by the thread. We store the P matrix interleaved. + int64_t row_offset = + (int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW; + // Finalize the pointer. + o_ptr_ += row_offset + col * BYTES_PER_ELEMENT; + } + + // Store data to memory. + template + inline __device__ void store(Accumulators const (&acc)[M][N]) { + int64_t const step_m = 8 * params_o_stride_in_bytes_; +#ifdef UNIFIED_EPILOGUE_SCALE + constexpr bool Scale = false; +#else + constexpr bool Scale = true; +#endif +#define STORE_COLUMNS() \ + { \ + /* we assume M = 1. some shortcuts. */ \ + static_assert(M == 1); \ + uint4 _src = { \ + .x = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2), \ + .y = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2), \ + .z = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1), \ + .w = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1), \ + }; \ + uint2 _dst = Acc_packer::run(this, _src); \ + int64_t _offset = (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD) * STEP_N; \ + fmha::stg(o_ptr_ + _offset, _dst); \ + } + +#pragma unroll + for (int ri = 0; ri < ROWS_PER_THREAD; ri++) { + if (row_ + ri * 8 >= actual_seqlen_) { + break; + } +#pragma unroll + for (int mma_ni = 0; mma_ni < VALID_MMAS_N - 1; ++mma_ni) { +#pragma unroll + for (int ci = 0; ci < COLS_PER_THREAD; ci += 2) { + STORE_COLUMNS() + } + } + // The last mma_n may not store full elements back to GMEM. + int mma_ni = VALID_MMAS_N - 1; +#pragma unroll + for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ci += 2) { + STORE_COLUMNS() + } + } + } + + // Move to the next location. + inline __device__ void move() { + row_ += ROWS; + o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; + } + + // The stride between rows for the QKV matrice. + int64_t params_o_stride_in_bytes_; + // Scaling factor; this usually means QKV descale factor in actuality + uint32_t params_scale_bmm2_; + // The pointer. + char* o_ptr_; + // The row loaded by this thread. + int row_; + // The length of the sequence loaded by that CTA. 
+ int actual_seqlen_; +}; + +} // namespace v2 + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h b/csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h new file mode 100644 index 0000000000..5ee0ac50d1 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h @@ -0,0 +1,146 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include + +namespace fmha { +namespace v2 { + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The number of bits per element. + int BITS_PER_ELEMENT, + // The number of rows of Q, K or V loaded by this tile. + int ROWS_, + // The number of columns. + int COLS, + // Do we use LDGSTS? + bool USE_LDGSTS_, + // Are attention heads interleaved? + bool HEADS_INTERLEAVED, + // The number of matrices + int NUM_MATS = 3> +struct Gmem_tile_tma_qkv { + // The size of each LDG. + enum { BYTES_PER_LDG = 16 }; + + // The size of a row in bytes. + enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; + + // The number of threads to load a "row" of the matrix. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; + + // The number of "rows" loaded per LDG. + enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = ROWS_ }; + + // The number of LDGs needed to load a chunk of the Q matrix. + enum { LDGS = fmha::Div_up::VALUE }; + + // The number of predicate registers. + enum { PRED_REGS = fmha::Compute_number_of_pred_regs::VALUE }; + + // Is it Hopper? + enum { + IS_HOPPER = std::is_same::value == true + }; + + // Make sure we use a single register to store predicates. Do not throw for Hopper for now. + static_assert(!USE_LDGSTS_ || PRED_REGS == 1 || IS_HOPPER, ""); + + // We do not use LDGSTS (for the moment). + enum { USE_LDGSTS = USE_LDGSTS_ }; + + // TMA DIMS, hard coded for now + enum { TMA_DIMS = 3 }; + + // TMA DESC type, hard coded for now + static constexpr fmha::cudaTmaDescType TMA_DESC_TYPE = fmha::cudaTmaDescType::TILED; + + // Ctor. 
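+  // (Packed-QKV offset example, assuming h == 32, h_kv == 8, d == 128 and head bidh == 10: the GQA
+  //  path below maps Q to coord[0] = 10 * 128 = 1280, K to 32 * 128 + (10 / 4) * 128 = 4352 and V to
+  //  32 * 128 + 8 * 128 + (10 / 4) * 128 = 5376, i.e. that head's channel offset inside one
+  //  [q_hd, k_h'd, v_h'd] token row.)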
+ template + inline __device__ Gmem_tile_tma_qkv(Params const& params, cudaTmaDesc const* p_desc, + int qkv_offset, Block_info const& block_info, int tidx, + int cta_row_offset = 0) + // in PACKED_QKV, q_stride = k_stride = v_stride + : params_qkv_stride_in_bytes_(params.q_stride_in_bytes), + actual_seqlen_(block_info.actual_seqlen), + qkv_ptr_(reinterpret_cast(params.qkv_ptr)), + p_desc_(p_desc) { + // Both MQA and GQA will use non HEADS_INTERLEAVED layout + if (params.h_kv < params.h) { + // QKV layout [b, s, [q_hd, k_h'd, v_h'd]] + int const hi = block_info.bidh; + int const hi_kv = block_info.bidh / (params.h / params.h_kv); + if (qkv_offset == 0) { // Q tensor + coord[0] = hi * params.d; + } else if (qkv_offset == 1) { // K tensor + coord[0] = params.h * params.d + hi_kv * params.d; + } else if (qkv_offset == 2) { // V tensor + coord[0] = params.h * params.d + params.h_kv * params.d + hi_kv * params.d; + } + } else { + coord[0] = qkv_offset * params.d + block_info.bidh * params.d * 3; + } + // coord[1] = block_info.bidb * params.s; // should be params.s * batch_idx + // coord[1] do not need to be adjusted per batch. + // since the gmem_ptr in tma desc is set per batch and already adjusted. + coord[1] = block_info.sum_s; + coord[2] = 0; + } + + // Store data to shared memory. + template + inline __device__ void commit(Smem_tile& smem_tile) {} + + // Load data from memory. + template + inline __device__ void load(Smem_tile& smem_tile) { + smem_tile.template store(p_desc_, coord); + } + + // Store data to memory. + inline __device__ void store(uint4 const (&data)[LDGS]) {} + + // Move the pointer to the next location. + // only needed by matrix Q. + inline __device__ void move() { + // coord[1] is incremented by STEP size. + coord[1] += ROWS; + } + + // The stride between rows for the QKV matrice. + int64_t params_qkv_stride_in_bytes_; + // The pointer. + char* qkv_ptr_; + // The register to store predicates. + uint32_t preds_[PRED_REGS]; + // The fetch registers. + uint4 fetch_[LDGS]; + // Keep track of the row the thread is processing as we move the tile. + int row_; + // The sequence length. + int actual_seqlen_; + // tma descriptor + cudaTmaDesc const* p_desc_; + // coord use by TMA. For now hard code to 3D. + int32_t coord[3]; +}; + +} // namespace v2 +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/gmma_descriptor.h b/csrc/fmha_v2/fmha/hopper/gmma_descriptor.h new file mode 100644 index 0000000000..8b4129e343 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/gmma_descriptor.h @@ -0,0 +1,547 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// whether transpose is applied on the smem before GMMA math execution +// if TN, notrans is applied to both A and B. as GMMA expects the data +// to be in TN format. +// if NT, trans is applied to both A and B. 
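+// (Usage note: the flag is carried on the descriptor type itself; e.g. the Fragment_accumulator::mma()
+// overloads in fragment.h test Gmma_single_desc_a/b::TRANS_MODE == Gmma_descriptor_transpose::TRANS
+// to select the transposed flavour of the hgmma_* instruction wrappers.)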
+//////////////////////////////////////////////////////////////////////////////////////////////////// +enum class Gmma_descriptor_transpose { TRANS, NOTRANS }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Gmma descriptor mode +// 2 bits to specify the descriptor mode. +//////////////////////////////////////////////////////////////////////////////////////////////////// +enum class Gmma_descriptor_mode { SWIZZLE_NONE = 0, SWIZZLE_128B, SWIZZLE_64B, SWIZZLE_32B }; +constexpr uint32_t GMMA_DESCRIPTOR_MODE_BITS = 2; +constexpr uint32_t GMMA_DESCRIPTOR_MODE_SHIFT = 62; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// number of descriptor per GMMA group to be actually allocated per kblock +//////////////////////////////////////////////////////////////////////////////////////////////////// +enum class Gmma_descriptor_size { + ONE, + TWO, // not yet implemented. might be needed for 64xNxK tile size. + // as many as needed (kblock / gmma_k). we may not prefer to use this as we may run out of UR + // budget + ALL +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// a single desc that has the info and bits +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Single_descriptor { + public: + // trans mode + static constexpr Gmma_descriptor_transpose TRANS_MODE = Gmma_trans; + + // set the single desc + inline __device__ void set(uint64_t const& desc_) { desc = desc_; } + + // get the single desc + inline __device__ uint64_t get() const { return desc; } + + private: + // the descriptor, each of 64 bit + uint64_t desc; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// for a +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Gmma_descriptor_a { + public: + // The type of the Single Descriptor + using Single_desc = Single_descriptor; + + // Transpose Mode + static constexpr Gmma_descriptor_transpose TRANS_MODE = Gmma_trans; + + // The number of descriptors per 64xNblockxKblock. + static constexpr Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = Gmma_vector_size; + + // Currently the number of descriptors per 64xNblockxKblock is always One + // Historically we have supported more descriptors. But that has proven to + // be less performant as it consumes too many uniform registers. + // During the process of refactoring we have decided to only support allocating + // one desc per 64xNblockxKblock. If needed in the future, we can support + // more desc. + static_assert(Gmma_vector_size == Gmma_descriptor_size::ONE, + "Currently, only Mblock/64 desc is allocated per kgroup\n"); + + // Interleaved Mode is currently not supported. + // static_assert to avoid accidentally instantiate it. + static_assert(Gmma_mode != Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, SWIZZLE_NONE mode is not implemented. \n"); + + // byte per leading dim (row if TN, column is NT) must be 128 + enum { BYTES_PER_LEADING_DIM = 128 }; + + // bytes per element + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // the number of descriptors per kblock is related to GMMA shape and kblock size + enum { + NUM_DESCRIPTORS = (Gmma_vector_size == Gmma_descriptor_size::ALL) ? 
Cta_tile::K / GMMA_K : 1 + }; + + // the number of descriptors per 128 byte in k dimension (leading dim) + // NUM_DESCRIPTORS_PER_128B_IN_K is really only needed if leading dim is K + enum { + NUM_DESCRIPTORS_PER_128B_IN_K = (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B && + Gmma_trans == Gmma_descriptor_transpose::NOTRANS) + ? BYTES_PER_LEADING_DIM / ((GMMA_K * BITS_PER_ELEMENT) / 8) + : NUM_DESCRIPTORS + }; + + static constexpr uint32_t BYTES_PER_GMMA_K = GMMA_K * BITS_PER_ELEMENT / 8; // 32B + + // the distance between neighboring descriptors + static constexpr uint32_t BYTES_PER_DESC = + Gmma_vector_size == Gmma_descriptor_size::ALL ? 0 + : Gmma_trans == Gmma_descriptor_transpose::TRANS + ? Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B ? GMMA_K * BYTES_PER_LEADING_DIM + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? (GMMA_K / 2) * BYTES_PER_LEADING_DIM + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B ? (GMMA_K / 4) * BYTES_PER_LEADING_DIM + : 0 + : Gmma_trans == Gmma_descriptor_transpose::NOTRANS + ? Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B || + Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B + ? BYTES_PER_GMMA_K // 32B + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B ? Cta_tile::M * BYTES_PER_GMMA_K + : 0 + : 0; + + // the distance between neighboring desc without 4LSB + static constexpr uint32_t BYTES_PER_DESC_NO_4LSB = BYTES_PER_DESC >> 4; + + // the distance to travel back from the last desc to the first desc within a group + enum { BYTES_DESC_INC_BOUNDARY_NO_4LSB = BYTES_PER_DESC_NO_4LSB * (Cta_tile::K / GMMA_K - 1) }; + + // set GMMA descriptor mode bits. + static constexpr uint64_t DESCRIPTOR_MODE_IN_BIT_LOCATION = + (static_cast(Gmma_mode) & ((1u << GMMA_DESCRIPTOR_MODE_BITS) - 1)) + << GMMA_DESCRIPTOR_MODE_SHIFT; + + // stride byte offset, bit 32-45, 4LSB not included + // each row is always of 128 byte. 8 rows always. + // divide by 16 since the 4 LSB is not included + static constexpr uint64_t STRIDE_BYTE_OFFSET = + BYTES_PER_LEADING_DIM * + ((Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) ? 8 + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2) / + 16; + // shift 32 bit + static constexpr uint64_t STRIDE_BYTE_OFFSET_IN_BIT_LOCATION = STRIDE_BYTE_OFFSET << 32; + + // leading byte offset, bit 16-29, 4LSB not included + // each row is still 128 byte. + // divide by 16 since the 4 LSB is not included + // for A matrix of TN, and the way we reshape the matrix, LEADING_BYTE_OFFSET is never non-zero + // in the future with different GMMA shape, this might be needed + static constexpr bool LEADING_BYTE_OFFSET_NEEDED = false; + + // the leading byte offset if needed 4LSB not included + static constexpr uint64_t LEADING_BYTE_OFFSET = + Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B + ? BYTES_PER_LEADING_DIM / 16 + : BYTES_PER_LEADING_DIM * + ((Gmma_trans == Gmma_descriptor_transpose::TRANS) ? Cta_tile::K : Cta_tile::M) / 16; + // shift 16 bit + static constexpr uint64_t LEADING_BYTE_OFFSET_IN_BIT_LOCATION = + LEADING_BYTE_OFFSET_NEEDED ? 
LEADING_BYTE_OFFSET << 16 : 0; + + // ctor + inline __device__ Gmma_descriptor_a() { +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] = 0; + } + +// set bit 62-63 to 1 for SWIZZLE_128B format +// set bit 62-63 to 2 for SWIZZLE_64B format +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= DESCRIPTOR_MODE_IN_BIT_LOCATION; + } + +// stride byte offset, bit 32-45, 4LSB not included +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= STRIDE_BYTE_OFFSET_IN_BIT_LOCATION; + } + + // leading byte offset, bit 16-29, 4LSB not included + if (LEADING_BYTE_OFFSET_NEEDED) { +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= LEADING_BYTE_OFFSET_IN_BIT_LOCATION; + } + } + } + + // update the descriptor based on smem address. Should be called once from prologue. + inline __device__ void set_smem_pointer(uint32_t smem_nvvm_pointer) { + // uint32_t smem_nvvm_pointer = get_smem_pointer(smem); + uint64_t smem_address_bit = static_cast(smem_nvvm_pointer); + + // set base offset, bit 49-61 + uint64_t offset = (smem_address_bit / BYTES_PER_LEADING_DIM) % + ((Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) ? 8 + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2); + uint64_t offset_in_bit_location = offset << 49; +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= offset_in_bit_location; + } + +// start_address, bit 0-13, 4LSB not included (so grab bit 4-17) +// the only bits that is different for each desc of the same obj +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + // for fp16, desc_idx_in_128B should range from 0 to 3 + int desc_idx_in_128B = desc_idx % NUM_DESCRIPTORS_PER_128B_IN_K; + int desc_idx_over_128B = desc_idx / NUM_DESCRIPTORS_PER_128B_IN_K; + + uint64_t smem_address_bit_in_bit_location = + (smem_address_bit + ((GMMA_K * BITS_PER_ELEMENT) / 8) * desc_idx_in_128B + + Cta_tile::M * BYTES_PER_LEADING_DIM * desc_idx_over_128B) + << 46; + + smem_address_bit_in_bit_location = smem_address_bit_in_bit_location >> 50; + desc[desc_idx] |= smem_address_bit_in_bit_location; + } + } + + // get a single desc from the desc group. + inline __device__ uint64_t get_descriptor(int desc_idx) const { + // printf("desc[0] = 0x%lx\n", desc[0]); + return desc[(Gmma_vector_size == Gmma_descriptor_size::ALL) ? desc_idx : 0]; + } + + // get the max descriptor for desc[0] + inline __device__ uint64_t get_max_descriptor_0() const { return max_desc_0; } + + // set a single desc from the desc group. + inline __device__ void set_descriptor(int desc_idx, uint64_t single_desc) { + desc[(Gmma_vector_size == Gmma_descriptor_size::ALL) ? desc_idx : 0] = single_desc; + } + + // set the max descriptor for desc[0]. Should be called once from prologue. + // Should be called with set_smem_pointer() + // This value is needed to "loop back" to the first LDGSTS buffer when appropriate. + inline __device__ void set_max_descriptor_0(int mem_offset_no_4LSB) { + max_desc_0 = desc[0] + mem_offset_no_4LSB; + } + + // for desc group where all desc all allocated, + // increment_single_descriptor() will do nothing. + inline __device__ void increment_single_descriptor(bool last_of_kblock) { + // update smem start address, which is in lower 32bits. 
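+    // Note: only the low 32 bits of the 64-bit descriptor are modified here. They contain the
+    // start-address field (bits 0-13, stored without its 4 LSB, i.e. in 16-byte units), which
+    // is why the byte distances above are pre-shifted right by 4 (see BYTES_PER_DESC_NO_4LSB).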
+ int2& tmp = reinterpret_cast(desc[0]); + if (last_of_kblock == true) { + tmp.x -= BYTES_DESC_INC_BOUNDARY_NO_4LSB; + } else { + tmp.x += BYTES_PER_DESC_NO_4LSB; + } + } + + template + inline __device__ void increment_single_descriptor() { + int2& tmp = reinterpret_cast(desc[0]); + tmp.x += (BYTE_OFFSET >> 4); + } + + private: + // the descriptors, each of 64 bit + uint64_t desc[NUM_DESCRIPTORS]; + // the max desc for desc_idx = 0 + uint64_t max_desc_0; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// for b +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Gmma_descriptor_b { + public: + // The type of the Single Descriptor + using Single_desc = Single_descriptor; + + // Transpose mode. + static constexpr Gmma_descriptor_transpose TRANS_MODE = Gmma_trans; + + // The number of descriptors per 64xNblockxKblock. + static constexpr Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = Gmma_vector_size; + + // Currently the number of descriptors per 64xNblockxKblock is always One + // Historically we have supported more descriptors. But that has proven to + // be less performant as it consumes too many uniform registers. + // During the process of refactoring we have decided to only support allocating + // one desc per 64xNblockxKblock. If needed in the future, we can support + // more desc. + static_assert(Gmma_vector_size == Gmma_descriptor_size::ONE, + "Currently, only Mblock/64 desc is allocated per kgroup\n"); + + // Interleaved Mode is currently not supported. + // static_assert to avoid accidentally instantiate it. + static_assert(Gmma_mode != Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, SWIZZLE_NONE mode is not implemented. \n"); + + // byte per leading dim (column if TN, row if NT), must be 128 + enum { BYTES_PER_LEADING_DIM = 128 }; + + // bytes per element + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // the number of descriptors per kblock is related to GMMA shape and kblock size + enum { + NUM_DESCRIPTORS = (Gmma_vector_size == Gmma_descriptor_size::ALL) ? Cta_tile::K / GMMA_K : 1 + }; + + // the number of descriptors per 128 byte in k dimension (leading dim) + // NUM_DESCRIPTORS_PER_128B_IN_K is really only needed if leading dim is K + enum { + NUM_DESCRIPTORS_PER_128B_IN_K = (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B && + Gmma_trans == Gmma_descriptor_transpose::NOTRANS) + ? BYTES_PER_LEADING_DIM / ((GMMA_K * BITS_PER_ELEMENT) / 8) + : NUM_DESCRIPTORS + }; + + static constexpr uint32_t BYTES_PER_GMMA_K = GMMA_K * BITS_PER_ELEMENT / 8; // 32B + + // the distance between neighboring descriptors + static constexpr uint32_t BYTES_PER_DESC = + Gmma_vector_size == Gmma_descriptor_size::ALL ? 0 + : Gmma_trans == Gmma_descriptor_transpose::TRANS + ? Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B ? GMMA_K * BYTES_PER_LEADING_DIM + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? (GMMA_K / 2) * BYTES_PER_LEADING_DIM + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B ? (GMMA_K / 4) * BYTES_PER_LEADING_DIM + : 0 + : Gmma_trans == Gmma_descriptor_transpose::NOTRANS + ? Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B || + Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B + ? BYTES_PER_GMMA_K // 32B + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_32B ? 
GMMA_N * BYTES_PER_GMMA_K + : 0 + : 0; + + // the distance between neighboring desc without 4LSB + static constexpr uint32_t BYTES_PER_DESC_NO_4LSB = BYTES_PER_DESC >> 4; + + // the distance to travel back from the last desc to the first desc within a group + enum { BYTES_DESC_INC_BOUNDARY_NO_4LSB = BYTES_PER_DESC_NO_4LSB * (Cta_tile::K / GMMA_K - 1) }; + + // Byte count on tile-K dimension + enum { + RESET_SMEM = ((Gmma_trans == Gmma_descriptor_transpose::NOTRANS) && + (((Cta_tile::K * BITS_PER_ELEMENT) / (8 * BYTES_PER_LEADING_DIM)) > 1)) + ? true + : false + }; + + // Reset bytes per BYTES_PER_LEADING_DIM (128) x tile-N + enum { RESET_BYTES_NO_4LSB = (BYTES_PER_LEADING_DIM * Cta_tile::N) / 16 }; + + // set GMMA descriptor mode bits. + static constexpr uint64_t DESCRIPTOR_MODE_IN_BIT_LOCATION = + (static_cast(Gmma_mode) & ((1u << GMMA_DESCRIPTOR_MODE_BITS) - 1)) + << GMMA_DESCRIPTOR_MODE_SHIFT; + + // stride byte offset, bit 32-45, 4LSB not included + // each column is always of 128 byte. 8 columns always. + // divide by 16 since the 4 LSB is not included + static constexpr uint64_t STRIDE_BYTE_OFFSET = + BYTES_PER_LEADING_DIM * + ((Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) ? 8 + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2) / + 16; + // shift 32 bit + static constexpr uint64_t STRIDE_BYTE_OFFSET_IN_BIT_LOCATION = STRIDE_BYTE_OFFSET << 32; + + // leading byte offset, bit 16-29, 4LSB not included + // each column is still 128 byte. + // divide by 16 since the 4 LSB is not included + // for B matrix of TN, and the way we reshape the matrix, LEADING_BYTE_OFFSET is never non-zero + // in the future with different GMMA shape, this might be needed + static constexpr bool LEADING_BYTE_OFFSET_NEEDED = + (((GMMA_N * BITS_PER_ELEMENT) / 8 > BYTES_PER_LEADING_DIM && + Gmma_trans == Gmma_descriptor_transpose::TRANS) || + GMMA_K == 64) + ? true + : false; + + // the leading byte offset if needed 4LSB not included + static constexpr uint64_t LEADING_BYTE_OFFSET = + GMMA_K == 64 + ? Cta_tile::N * 32 / 16 + : (BYTES_PER_LEADING_DIM * + ((Gmma_trans == Gmma_descriptor_transpose::TRANS) ? Cta_tile::K : Cta_tile::N) / 16); + // shift 16 bit + static constexpr uint64_t LEADING_BYTE_OFFSET_IN_BIT_LOCATION = + LEADING_BYTE_OFFSET_NEEDED ? LEADING_BYTE_OFFSET << 16 : 0; + + // ctor + inline __device__ Gmma_descriptor_b() { +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] = 0; + } + +// set bit 62-63 to 1 for SWIZZLE_128B format +// set bit 62-63 to 2 for SWIZZLE_64B format +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= DESCRIPTOR_MODE_IN_BIT_LOCATION; + } + +// stride byte offset, bit 32-45, 4LSB not included +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= STRIDE_BYTE_OFFSET_IN_BIT_LOCATION; + } + + // leading byte offset, bit 16-29, 4LSB not included + if (LEADING_BYTE_OFFSET_NEEDED) { +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= LEADING_BYTE_OFFSET_IN_BIT_LOCATION; + } + } + } + + // update the descriptor based on smem address. Should be called once from prologue. 
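+  // Illustrative layout (SWIZZLE_128B case): the 64-bit descriptor packs the swizzle mode (1)
+  // into bits 62-63, the base offset ((smem_addr / 128) % 8) into bits 49-61, the stride byte
+  // offset (128 * 8 / 16 = 64) into bits 32-45, the optional leading byte offset into
+  // bits 16-29, and the smem start address without its 4 LSB (effectively
+  // (smem_addr >> 4) & 0x3fff) into bits 0-13. The ctor above fills in the static fields;
+  // set_smem_pointer() below fills in the address-dependent ones.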
+ inline __device__ void set_smem_pointer(uint32_t smem_nvvm_pointer) { + // uint64_t smem_address_bit = reinterpret_cast(smem); + // uint32_t smem_nvvm_pointer = get_smem_pointer(smem); + uint64_t smem_address_bit = static_cast(smem_nvvm_pointer); + + // set base offset, bit 49-61 + uint64_t offset = (smem_address_bit / BYTES_PER_LEADING_DIM) % + ((Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) ? 8 + : Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2); + uint64_t offset_in_bit_location = offset << 49; +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + desc[desc_idx] |= offset_in_bit_location; + } + +// start_address, bit 0-13, 4LSB not included(so grab bit 4-17) +// the only bits that is different for each desc of the same obj +#pragma unroll + for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) { + // for fp16, desc_idx_in_128B should range from 0 to 3 + int desc_idx_in_128B = desc_idx % NUM_DESCRIPTORS_PER_128B_IN_K; + int desc_idx_over_128B = desc_idx / NUM_DESCRIPTORS_PER_128B_IN_K; + + uint64_t smem_address_bit_in_bit_location = + (smem_address_bit + ((GMMA_K * BITS_PER_ELEMENT) / 8) * desc_idx_in_128B + + Cta_tile::N * BYTES_PER_LEADING_DIM * desc_idx_over_128B) + << 46; + smem_address_bit_in_bit_location = smem_address_bit_in_bit_location >> 50; + desc[desc_idx] |= smem_address_bit_in_bit_location; + } + } + + // get a single desc from the desc group. + inline __device__ uint64_t get_descriptor(int desc_idx) const { + // if(threadIdx.x == 128) + // printf("desc[0] = 0x%lx\n", desc[0]); + //__syncwarp(); + return desc[(Gmma_vector_size == Gmma_descriptor_size::ALL) ? desc_idx : 0]; + } + + // get the max descriptor for desc[0] + inline __device__ uint64_t get_max_descriptor_0() const { return max_desc_0; } + + // set a single desc from the desc group. + inline __device__ void set_descriptor(int desc_idx, uint64_t single_desc) { + desc[(Gmma_vector_size == Gmma_descriptor_size::ALL) ? desc_idx : 0] = single_desc; + } + + // set the max descriptor for desc[0]. Should be called once from prologue. + // Should be called with set_smem_pointer() + // This value is needed to "loop back" to the first LDGSTS buffer when appropriate. + inline __device__ void set_max_descriptor_0(int mem_offset_no_4LSB) { + max_desc_0 = desc[0] + mem_offset_no_4LSB; + } + + // for desc group where all desc all allocated, + // increment_single_descriptor() will do nothing. + inline __device__ void increment_single_descriptor(bool last_of_kblock) { + // update smem start address, which is in lower 32bits. + int2& tmp = reinterpret_cast(desc[0]); + if (last_of_kblock == true) { + tmp.x -= BYTES_DESC_INC_BOUNDARY_NO_4LSB; + } else { + tmp.x += BYTES_PER_DESC_NO_4LSB; + } + } + + template + inline __device__ void increment_single_descriptor() { + int2& tmp = reinterpret_cast(desc[0]); + tmp.x += (BYTE_OFFSET >> 4); + } + + // for desc group where all desc all allocated, + // increment_single_descriptor() will do nothing. + inline __device__ void increment_single_descriptor(bool last_of_kblock, bool switch_kblock) { + // update smem start address, which is in lower 32bits. 
+ int2& tmp = reinterpret_cast(desc[0]); + if (RESET_SMEM) { + if (switch_kblock) { + tmp.x -= BYTES_PER_DESC_NO_4LSB; + tmp.x += RESET_BYTES_NO_4LSB; + } else { + if (last_of_kblock == true) { + tmp.x -= BYTES_PER_DESC_NO_4LSB; + tmp.x -= RESET_BYTES_NO_4LSB; + } else { + tmp.x += BYTES_PER_DESC_NO_4LSB; + } + } + } else { + if (last_of_kblock == true) { + tmp.x -= BYTES_DESC_INC_BOUNDARY_NO_4LSB; + } else { + tmp.x += BYTES_PER_DESC_NO_4LSB; + } + } + } + + private: + // the descriptors, each of 64 bit + uint64_t desc[NUM_DESCRIPTORS]; + // the max desc for desc_idx = 0 + uint64_t max_desc_0; +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/kernel_traits.h b/csrc/fmha_v2/fmha/hopper/kernel_traits.h new file mode 100644 index 0000000000..edeff1e281 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/kernel_traits.h @@ -0,0 +1,365 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // Instruction traits. + typename Traits_p_, + // Instruction traits. + typename Traits_o_, + // The ldgsts global memory tile for Q, K and V. + template class Gmem_tile_qkv_, + // The tma global memory tile for Q, K and V. + template class Gmem_tile_tma_qkv_, + // The global memory tile for the output. + template class Gmem_tile_o_, + // Sequence length. + int S, + // The hidden dimension. + int D, + // The iteration step of the outer loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The version of the kernel. + int VERSION_, + // The mask version of the kernel, (2 denotes dense mask, 3 denotes causal mask) + int MASK_VERSION_ = 2, + // The flags to control the behaviour of LDGs. + uint32_t FLAGS = 0x8u> +struct FMHA_kernel_traits_hopper { + // The instruction traits for the Q*K product. + using Traits_p = Traits_p_; + + // is Q operand in RF for GMMA? + static constexpr bool GMMA_Q_RF = Traits_p::GMMA_A_RF; + + // is K operand in RF for GMMA? + static constexpr bool GMMA_K_RF = Traits_p::GMMA_B_RF; + + // The instruction traits for P*V product. + using Traits_o = Traits_o_; + + // is S operand in RF for GMMA? + static constexpr bool GMMA_S_RF = Traits_o::GMMA_A_RF; + + // is V operand in RF for GMMA? + static constexpr bool GMMA_V_RF = Traits_o::GMMA_B_RF; + + // The number of warpgroups along M dimension + enum { WARP_GROUP_M = WARPS_M / 4 }; + + // The number of warpgroups along N dimension + enum { WARP_GROUP_N = WARPS_N }; + + // The number of warpgroups along K dimension + enum { WARP_GROUP_K = 1 }; + + // The CTA description for the 1st GEMM. + using Cta_tile_p = + typename Traits_p::template Cta_tile; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = + typename Traits_o::template Cta_tile; + + // The version. 
+ enum { VERSION = VERSION_ }; + + enum { MASK_VERSION = MASK_VERSION_ }; + + // Whether use causal mask or not. + enum { CAUSAL_MASK = MASK_VERSION_ >= 3 }; + + // Whether use the sliding window attention mask or not. + enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4 }; + + // Do we use LDGSTS for Q, K or V. If not, TMA is used! + enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u }; + + enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u }; + + enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u }; + + enum { USE_TMA_Q = !USE_LDGSTS_Q }; + + enum { USE_TMA_K = !USE_LDGSTS_K }; + + enum { USE_TMA_V = !USE_LDGSTS_V }; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = 0 }; + + // Do we use the scale max trick. + enum { USE_SCALE_MAX = 0 }; + + // Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d. + enum { HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = (FLAGS & 0x800) != 0u }; + + // Number of matrix for gmem_tile_qkv + enum { NUM_QKV_MATS = 3 }; + + // The global memory tile to load Q. + // Hopefully we don't need to specialize for Hopper. + using Gmem_tile_ldgsts_q = + Gmem_tile_qkv_; + + // The global memory tile to load Q with TMA. + using Gmem_tile_tma_q = Gmem_tile_tma_qkv_; + + // Do we use ldgsts gmem tile or tma gmem tile? + using Gmem_tile_q = + typename std::conditional_t; + + // 2 buffers for Q + enum { BUFFERS_PER_SMEM_TILE_Q = 2 }; + + // Q is row major + using Q_layout = fmha::Row; + + // We know Q is row-major. So we can also deduce the descriptor mode. + static constexpr fmha::Gmma_descriptor_mode GMMA_DESC_MODE_Q = + Cta_tile_p::K * sizeof(typename Traits_p::A_type) >= 128 + ? fmha::Gmma_descriptor_mode::SWIZZLE_128B + : fmha::Gmma_descriptor_mode::SWIZZLE_64B; + + // The shared memory tile to swizzle Q. + using Smem_tile_ldgsts_q = + fmha::Smem_tile_hopper_a; + + // The shared memory tile to swizzle Q. TODO: need to update to XMMA. + using Smem_tile_tma_q = + fmha::wip::Smem_tile_hopper_a; + + using Smem_tile_q = + typename std::conditional_t; + + // The global memory tile to load K. + // Hopefully we don't need to specialize for hopper. + using Gmem_tile_ldgsts_k = + Gmem_tile_qkv_; + + // The global memory tile to load K with TMA. + using Gmem_tile_tma_k = Gmem_tile_tma_qkv_; + + // Do we use ldgsts gmem tile or tma gmem tile? + using Gmem_tile_k = + typename std::conditional_t; + + // 1 buffers for K + enum { BUFFERS_PER_SMEM_TILE_K = 1 }; + + // K is column major + using K_layout = fmha::Col; + + // We know K is column-major. So we can also deduce the descriptor mode. + static constexpr fmha::Gmma_descriptor_mode GMMA_DESC_MODE_K = + Cta_tile_p::K * sizeof(typename Traits_p::B_type) >= 128 + ? fmha::Gmma_descriptor_mode::SWIZZLE_128B + : fmha::Gmma_descriptor_mode::SWIZZLE_64B; + + // The shared memory tile to swizzle K. + using Smem_tile_ldgsts_k = + fmha::Smem_tile_hopper_b; + + using Smem_tile_tma_k = + fmha::wip::Smem_tile_hopper_b; + + using Smem_tile_k = + typename std::conditional_t; + + // The global memory tile to load V. + using Gmem_tile_ldgsts_v = + Gmem_tile_qkv_; + + // The global memory tile to load V with TMA. + using Gmem_tile_tma_v = Gmem_tile_tma_qkv_; + + // Do we use ldgsts gmem tile or tma gmem tile? + using Gmem_tile_v = + typename std::conditional_t; + + // 1 buffers for V + enum { BUFFERS_PER_SMEM_TILE_V = 1 }; + + // V is row major + using V_layout = fmha::Row; + + // We know V is row marjor. So we can also deduce the descriptor mode. 
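+  // The rule mirrors Q and K above: if one leading-dim row of the operand spans at least
+  // 128 bytes, the 128B swizzle is chosen, otherwise the 64B swizzle. As an illustrative
+  // example, fp16 with a 64-wide N per CTA gives 64 * 2B = 128B, i.e. SWIZZLE_128B.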
+ static constexpr fmha::Gmma_descriptor_mode GMMA_DESC_MODE_V = + Cta_tile_o::N * sizeof(typename Traits_o::B_type) >= 128 + ? fmha::Gmma_descriptor_mode::SWIZZLE_128B + : fmha::Gmma_descriptor_mode::SWIZZLE_64B; + + // The shared memory tile to swizzle V. + using Smem_tile_ldgsts_v = fmha::Smem_tile_v; + + using Smem_tile_tma_v = + fmha::wip::Smem_tile_hopper_b; + + using Smem_tile_v = + typename std::conditional_t; + + // The global memory tile to store O. + // using Gmem_tile_o = fmha::Gmem_tile_o_hopper; + using Gmem_tile_o = fmha::v2::Gmem_tile_o; + + using Smem_tile_o_ = fmha::Smem_tile_o; + static constexpr bool NEEDS_SPLIT_K = WARPS_N > 1; + using Smem_tile_o = + typename std::conditional_t; + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + + // The amount of shared memory needed to load Q and store O. + // enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; + // For now let's pretend no smem for O matrix. [Timmy] + enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE }; + + // The amount of over allocated smem to guarantee 1024B alignment. + enum { BYTES_FOR_ALIGNMENT = 1024 }; + + // The size in bytes for each SMEM barrier + enum { BYTES_PER_SMEM_BARRIER = 8 }; + + // The amount of smem used by smem barrier. Only needed if TMA is used. + enum { + BYTES_FOR_SMEM_BARRIER_Q = + USE_LDGSTS_Q == 1 ? 0 : BUFFERS_PER_SMEM_TILE_Q * BYTES_PER_SMEM_BARRIER + }; + + // The amount of smem used by smem barrier. Only needed if TMA is used. + // each smem barrier is 8 bytes, each buffer has 2 barriers + enum { + BYTES_FOR_SMEM_BARRIER_K = + USE_LDGSTS_K == 1 ? 0 : BUFFERS_PER_SMEM_TILE_K * BYTES_PER_SMEM_BARRIER + }; + + // The amount of smem used by smem barrier. Only needed if TMA is used. + // Currently, K and V can share the same barrier. + enum { BYTES_FOR_SMEM_BARRIER_V = 0 }; + + // The amount of smem used by smem barrier. Only needed if TMA is used. + enum { + BYTES_FOR_SMEM_BARRIER = + BYTES_FOR_SMEM_BARRIER_Q + BYTES_FOR_SMEM_BARRIER_K + BYTES_FOR_SMEM_BARRIER_V + }; + + // TODO move those + enum { BYTES_FOR_SOFTMAX = WARPS_N == 1 ? 0 : sizeof(float) * WARPS_N * 64 }; + + enum { + BYTES_PER_SMEM_O = + WARPS_N == 1 ? 0 : WARPS_N * 64 * D * sizeof(typename Traits_o::Epilogue_type) + }; + + static_assert(Smem_tile_o::BYTES_PER_TILE == (int)BYTES_PER_SMEM_O); + + // The amount of shared memory needed for Q, K, V and O. + // TODO double check. + // - For GMMA QKV are always stored in SMEM. + // - Cannot share SMEM K/V + // - O needs to be separate + // enum { BYTES_PER_SMEM = fmha::Max::VALUE + enum { + BYTES_PER_SMEM = BYTES_PER_SMEM_QKV + BYTES_PER_SMEM_O + BYTES_FOR_SOFTMAX + + BYTES_FOR_SMEM_BARRIER + BYTES_FOR_ALIGNMENT + }; + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The compute tile for P = Q*K. + using Compute_tile_p = + fmha::Compute_tile_with_gmma; + // The compute tile for O = S*V. 
+ using Compute_tile_o = + fmha::Compute_tile_with_gmma; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The BMM1 instruction traits. + typename Traits_p, + // The BMM2 instruction traits. + typename Traits_o, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The attention mask type (2 denotes dense mask, 3 denotes causal mask). + int MASK_VERSION, + // The flags. + uint32_t FLAGS = 0x8> +using FMHA_kernel_traits_hopper_v2 = + FMHA_kernel_traits_hopper; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/fmha_v2/fmha/hopper/smem_tile.h b/csrc/fmha_v2/fmha/hopper/smem_tile.h new file mode 100644 index 0000000000..b921b48db2 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/smem_tile.h @@ -0,0 +1,2423 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// +/// @brief Interface to Smem tiles for a operator +// HGMMA +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class Gmma_fusion_mode { NO_FUSION, BN_APPLY }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace wip { + +template +struct Smem_tile_hopper_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_hopper_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Col Major. For GMMA, A is from SMEM directly. +// Not implemented, since it is not really needed at the moment. +template +struct Smem_tile_hopper_gmma_col_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Row Major. For GMMA, A is from SMEM directly. +template +struct Smem_tile_hopper_gmma_row_a { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, SWIZZLE_NONE Mode is not implemented.\n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited). + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor. + using Gmma_descriptor = + fmha::Gmma_descriptor_a; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. 
+ enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_A }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B + // and SWIZZLE_64B format. + enum { BYTES_PER_ROW = 128 }; + + // the number of rows per one row of K due the the limitation of leading dim size. + enum { NUM_ROWS_PER_K = (Cta_tile::K * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW }; + + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B || + (Cta_tile::K * BYTES_PER_ELEMENT) == 64, + "swizzle_64B row_a is valid if kblock=32\n"); + + // Number of SMEM rows. + enum { + NUM_ROWS = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) + ? (Cta_tile::M * NUM_ROWS_PER_K) + : (Cta_tile::M / 2) + }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_ROWS * BYTES_PER_ROW }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer. + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc. + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a row + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // for swizzle_128B the xor factor is 8 + enum { ROWS_PER_XOR_PATTERN = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) ? 8 : 4 }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = Mma_tile::M_PER_GMMA_GROUP / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 : 2) * + BYTES_PER_ROW + }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // For Hopper, STS_PER_ROW should be 1 (at least for now.) + static_assert(STS_PER_ROW == 1, ""); + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_row_a(char* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)) { + int smem_write_row = tidx / THREADS_PER_ROW; + int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN; + int smem_write_col = 0; + + if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) { + smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + } else if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B) { + smem_write_col = (tidx % (THREADS_PER_ROW / 2)) ^ + smem_write_xor + ((tidx % THREADS_PER_ROW) / (THREADS_PER_ROW / 2)) * 4; + } + + this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; + + // That code is expected to trigger the utilization of the URF by the compiler. 
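+    // __shfl_sync of a literal 0 returns a value that is provably uniform across the warp,
+    // nudging the compiler to keep these buffer offsets in uniform registers (URF) rather
+    // than in per-thread registers. The trick does not change the stored value (still 0).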
+ this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int row = ii / STS_PER_ROW; + // Assemble the offset. + int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW; + // Assemble the final pointer :) + ptrs[ii] = smem_ + offset + smem_write_buffer_; + } + } + + // Store the tile in the shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Move the write offset to next buffer. + inline __device__ void move_next_write_buffer() { + if (BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += (smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) + ? -BYTES_PER_TILE_INC_BOUNDARY + : BYTES_PER_BUFFER; + } + } + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Col Major. For GMMA, B is from SMEM directly. +template +struct Smem_tile_hopper_gmma_col_b { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, Interleaved Mode is not implemented.\n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_B }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B and + // SWIZZLE_64B format + enum { BYTES_PER_COLUMN = 128 }; + + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B || + (Cta_tile::K * BYTES_PER_ELEMENT) == 64, + "swizzle_64B col_b is valid if kblock=32\n"); + + // the number of columns per one column of K due the the limitation of leading dim size + enum { + NUM_COLS_PER_K = (Cta_tile::K * BYTES_PER_ELEMENT + BYTES_PER_COLUMN - 1) / BYTES_PER_COLUMN + }; + + // Number of SMEM columns. + enum { + NUM_COLUMNS = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) + ? Cta_tile::N * NUM_COLS_PER_K + : Cta_tile::N / 2 + }; + + // The size of one buffer in bytes in shared memory. 
+ enum { BYTES_PER_BUFFER = NUM_COLUMNS * BYTES_PER_COLUMN }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc. + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a column. + enum { THREADS_PER_COLUMN = BYTES_PER_COLUMN / BYTES_PER_STS }; + + // The number of columns written with a single STS. + enum { COLUMNS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_COLUMN }; + + // for swizzle_128B the xor factor is 8. + enum { + COLUMNS_PER_XOR_PATTERN = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) ? 8 : 4 + }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = Mma_tile::N_PER_GMMA_GROUP / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 : 2) * + BYTES_PER_COLUMN + }; + + // The number of STS per column. + enum { STS_PER_COLUMN = BYTES_PER_COLUMN / THREADS_PER_COLUMN / BYTES_PER_STS }; + + // For Hopper, STS_PER_COLUMN should be 1 (at least for now.) + static_assert(STS_PER_COLUMN == 1, ""); + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_col_b(char* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)) { + int smem_write_col = tidx / THREADS_PER_COLUMN; + int smem_write_xor = smem_write_col % COLUMNS_PER_XOR_PATTERN; + int smem_write_row = 0; + + if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) { + smem_write_row = (tidx % THREADS_PER_COLUMN) ^ smem_write_xor; + } else if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B) { + smem_write_row = + (tidx % (THREADS_PER_COLUMN / 2)) ^ + smem_write_xor + ((tidx % THREADS_PER_COLUMN) / (THREADS_PER_COLUMN / 2)) * 4; + } + + this->smem_write_offset_ = smem_write_col * BYTES_PER_COLUMN + smem_write_row * BYTES_PER_STS; + // That code is expected to trigger the utilization of the URF by the compiler. + this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int col = ii / STS_PER_COLUMN; + // Assemble the offset. + int offset = smem_write_offset_ + col * COLUMNS_PER_STS * BYTES_PER_COLUMN; + // Assemble the final pointer :) + ptrs[ii] = smem_ + offset + smem_write_buffer_; + } + } + + // Store the tile in the shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Move the write offset to next buffer. 
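+  // Kept commented out below: with a single buffer per tile (BUFFERS_PER_TILE == 1, as the
+  // kernel traits above use for K and V), advancing the write offset would be a no-op anyway.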
+ inline __device__ void move_next_write_buffer() { + // if( BUFFERS_PER_TILE > 1 ) { + // this->smem_write_offset_ += ( smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) + // ? -BYTES_PER_TILE_INC_BOUNDARY + // : BYTES_PER_BUFFER; + // } + } + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Row Major. For GMMA, B is from SMEM directly. +template +struct Smem_tile_hopper_gmma_row_b { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, Interleaved Mode is not implemented.\n"); + + // For SWIZZLE_64B, row b is not needed/implemented + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B, + "Currently, for SWIZZLE_64B mode, row_b is not needed/implemented. \n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_B }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B and + // SWIZZLE_64B format + enum { BYTES_PER_ROW = 128 }; + + // the number of rows per one row of N due the the limitation of leading dim size + enum { NUM_ROWS_PER_N = (Cta_tile::N * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW }; + + // the number of rows per one row of N_PER_GMMA_GROUP + enum { + NUM_ROWS_PER_GMMA_GROUP_N = + (Mma_tile::N_PER_GMMA_GROUP * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW + }; + + // Number of SMEM rows + enum { NUM_ROWS = Cta_tile::K * NUM_ROWS_PER_N }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_ROWS * BYTES_PER_ROW }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. 
+ enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a row + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // for swizzle_128B the xor factor is 8 + enum { ROWS_PER_XOR_PATTERN = 8 }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = + Mma_tile::K_PER_GMMA_GROUP * NUM_ROWS_PER_GMMA_GROUP_N * BYTES_PER_ROW + }; + + // The number of STS per ROW. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // For Hopper, STS_PER_ROW should be 1 (at least for now.) + static_assert(STS_PER_ROW == 1, ""); + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_row_b(char* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)) { + int smem_write_row = tidx / THREADS_PER_ROW; + int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN; + int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; + // That code is expected to trigger the utilization of the URF by the compiler. + this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int row = ii / STS_PER_ROW; + // Assemble the offset. + int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW; + + // Assemble the final pointer :) + ptrs[ii] = smem_ + offset + smem_write_buffer_; + } + } + + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M], uint64_t = 0) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Move the write offset to next buffer. + inline __device__ void move_next_write_buffer() { + // if( BUFFERS_PER_TILE > 1 ) { + // this->smem_write_offset_ += ( smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) + // ? -BYTES_PER_TILE_INC_BOUNDARY + // : BYTES_PER_BUFFER; + // } + } + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Specialized Interface +// LDGSTS smem tiles. +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Col Major, A coming from SMEM +template +struct Smem_tile_hopper_a + : public Smem_tile_hopper_gmma_col_a { + // The base class. + using Base = Smem_tile_hopper_gmma_col_a; + + // Ctor. + // comment the implementation out as a mark that this is not supported, yet. 
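+  // (The kernel traits above declare Q as row-major, so this col-major A specialization is
+  // not instantiated by those kernels.)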
+ // inline __device__ Smem_tile_hopper_a( char *smem, int tidx ) : Base( smem, tidx ) { + //} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Row Major, A coming from SMEM +template +struct Smem_tile_hopper_a + : public Smem_tile_hopper_gmma_row_a { + // The base class. + using Base = Smem_tile_hopper_gmma_row_a; + + // Ctor. + inline __device__ Smem_tile_hopper_a(char* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Col Major, B coming from SMEM +template +struct Smem_tile_hopper_b + : public Smem_tile_hopper_gmma_col_b { + // The base class. + using Base = Smem_tile_hopper_gmma_col_b; + + // Ctor. + inline __device__ Smem_tile_hopper_b(char* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Row Major, B coming from SMEM +template +struct Smem_tile_hopper_b + : public Smem_tile_hopper_gmma_row_b { + // The base class. + using Base = Smem_tile_hopper_gmma_row_b; + + // Ctor. + inline __device__ Smem_tile_hopper_b(char* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Specialized Interface +// TMA smem tiles. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Row Major. For GMMA, A is from SMEM directly. +template +struct Smem_tile_hopper_gmma_tma_row_a { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, SWIZZLE_NONE Mode is not implemented.\n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited). + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor. + using Gmma_descriptor = + fmha::Gmma_descriptor_a; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_A }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B + // and SWIZZLE_64B format. + enum { BYTES_PER_ROW = 128 }; + + // the number of rows per one row of K due the the limitation of leading dim size. + enum { NUM_ROWS_PER_K = (Cta_tile::K * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW }; + + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B || + (Cta_tile::K * BYTES_PER_ELEMENT) == 64, + "swizzle_64B row_a is valid if kblock=32\n"); + + // Number of SMEM rows. + enum { + NUM_ROWS = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) + ? (Cta_tile::M * NUM_ROWS_PER_K) + : (Cta_tile::M / 2) + }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_ROWS * BYTES_PER_ROW }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. 
+ // this is needed to increment the GMMA desc to the next buffer. + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc. + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a row + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // for swizzle_128B the xor factor is 8 + enum { ROWS_PER_XOR_PATTERN = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) ? 8 : 4 }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = Mma_tile::M_PER_GMMA_GROUP / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 : 2) * + BYTES_PER_ROW + }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // For Hopper, STS_PER_ROW should be 1 (at least for now.) + static_assert(STS_PER_ROW == 1, ""); + + // Each smem barrier is of 8 bytes + enum { BYTES_PER_SMEM_BARRIER = 8 }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { + BYTES_PER_TILE_INC_BOUNDARY_SMEM_BARRIER = + BYTES_PER_SMEM_BARRIER * BUFFERS_PER_TILE - BYTES_PER_SMEM_BARRIER + }; + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_tma_row_a(char* smem, char* smem_barrier) + : smem_(__nvvm_get_smem_pointer(smem)), + smem_barrier_(__nvvm_get_smem_pointer(smem_barrier)), + smem_write_offset_(0), + smem_barrier_offset_(0) {} + + // Move the write offset to next buffer. + // Also move the smem_barrier. + inline __device__ void move_next_write_buffer() { + if (BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += (smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) + ? -BYTES_PER_TILE_INC_BOUNDARY + : BYTES_PER_BUFFER; + } + + // also update the smem_barrier. + if (BUFFERS_PER_TILE > 1) { + this->smem_barrier_offset_ += + (smem_barrier_offset_ >= BYTES_PER_TILE_INC_BOUNDARY_SMEM_BARRIER) + ? -BYTES_PER_TILE_INC_BOUNDARY_SMEM_BARRIER + : BYTES_PER_SMEM_BARRIER; + } + } + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + template + inline __device__ void store(cudaTmaDesc const* p_desc, int32_t const (&coord)[DIM], + uint16_t filter_offsets = 0, uint16_t mcast_cta_mask = 0) { + fmha::utmaldg(p_desc, smem_ + smem_write_offset_, + smem_barrier_ + smem_barrier_offset_, coord); + } + + // The shared memory pointer. + uint32_t smem_; + // The barrier in smem. + uint32_t smem_barrier_; + // The write offset. + int smem_write_offset_; + // The smem barrier offset + int smem_barrier_offset_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Col Major. For GMMA, B is from SMEM directly. 
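+// Unlike the LDGSTS variant above, the TMA tile computes no per-thread swizzled write
+// offsets: store() issues a utmaldg against a cudaTmaDesc plus an smem barrier, and the
+// swizzled copy into shared memory is carried out by the TMA unit.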
+template +struct Smem_tile_hopper_gmma_tma_col_b { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, Interleaved Mode is not implemented.\n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_B }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B and + // SWIZZLE_64B format + enum { BYTES_PER_COLUMN = 128 }; + + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B || + (Cta_tile::K * BYTES_PER_ELEMENT) == 64, + "swizzle_64B col_b is valid if kblock=32\n"); + + // the number of columns per one column of K due the the limitation of leading dim size + enum { + NUM_COLS_PER_K = (Cta_tile::K * BYTES_PER_ELEMENT + BYTES_PER_COLUMN - 1) / BYTES_PER_COLUMN + }; + + // Number of SMEM columns. + enum { + NUM_COLUMNS = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) + ? Cta_tile::N * NUM_COLS_PER_K + : Cta_tile::N / 2 + }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_COLUMNS * BYTES_PER_COLUMN }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc. + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a column. + enum { THREADS_PER_COLUMN = BYTES_PER_COLUMN / BYTES_PER_STS }; + + // The number of columns written with a single STS. + enum { COLUMNS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_COLUMN }; + + // for swizzle_128B the xor factor is 8. + enum { + COLUMNS_PER_XOR_PATTERN = (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B) ? 8 : 4 + }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = Mma_tile::N_PER_GMMA_GROUP / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 : 2) * + BYTES_PER_COLUMN + }; + + // The number of STS per column. + enum { STS_PER_COLUMN = BYTES_PER_COLUMN / THREADS_PER_COLUMN / BYTES_PER_STS }; + + // For Hopper, STS_PER_COLUMN should be 1 (at least for now.) + static_assert(STS_PER_COLUMN == 1, ""); + + // Ctor. 
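+  // The smem_barrier handed in here is the arrival barrier that tracks completion of the
+  // TMA copies for this tile; the kernel traits above reserve BYTES_PER_SMEM_BARRIER (8 B)
+  // per buffer for it.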
+ inline __device__ Smem_tile_hopper_gmma_tma_col_b(char* smem, char* smem_barrier) + : smem_(__nvvm_get_smem_pointer(smem)), + smem_barrier_(__nvvm_get_smem_pointer(smem_barrier)) {} + + // Move the write offset to next buffer. + // Not implemented as it is not needed currently. + inline __device__ void move_next_write_buffer() {} + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + template + inline __device__ void store(cudaTmaDesc const* p_desc, int32_t const (&coord)[DIM], + uint16_t filter_offsets = 0, uint16_t mcast_cta_mask = 0) { + fmha::utmaldg(p_desc, smem_, smem_barrier_, coord); + } + + // The shared memory pointer. + uint32_t smem_; + // The barrier in smem. + uint32_t smem_barrier_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Row Major. For GMMA, B is from SMEM directly. +template +struct Smem_tile_hopper_gmma_tma_row_b { + // Currently Interleaved Mode is not implemented. + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_NONE, + "Currently, Interleaved Mode is not implemented.\n"); + + // For SWIZZLE_64B, row b is not needed/implemented + static_assert(desc_mode != fmha::Gmma_descriptor_mode::SWIZZLE_64B, + "Currently, for SWIZZLE_64B mode, row_b is not needed/implemented. \n"); + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + using Cta_tile_gmma = Cta_tile; + + // the size in bits of each element. + enum { BITS_PER_ELEMENT = Traits::BITS_PER_ELEMENT_B }; + + // the size of bytes of each element. + enum { BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8 }; + + // The size in bytes of a single LDGSTS/STS. + enum { BYTES_PER_STS = 16 }; + + // The number of elements per LDGSTS/STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // SMEM layout for GMMA has a leading dim of exact 128 Byte, at least for SWIZZLE_128B and + // SWIZZLE_64B format + enum { BYTES_PER_ROW = 128 }; + + // the number of rows per one row of N due the the limitation of leading dim size + enum { NUM_ROWS_PER_N = (Cta_tile::N * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW }; + + // the number of rows per one row of N_PER_GMMA_GROUP + enum { + NUM_ROWS_PER_GMMA_GROUP_N = + (Mma_tile::N_PER_GMMA_GROUP * BYTES_PER_ELEMENT + BYTES_PER_ROW - 1) / BYTES_PER_ROW + }; + + // Number of SMEM rows + enum { NUM_ROWS = Cta_tile::K * NUM_ROWS_PER_N }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = NUM_ROWS * BYTES_PER_ROW }; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. 
+ enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of threads needed to store a row + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // for swizzle_128B the xor factor is 8 + enum { ROWS_PER_XOR_PATTERN = 8 }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + enum { + GMMA_GROUP_SMEM_DISTANCE = + Mma_tile::K_PER_GMMA_GROUP * NUM_ROWS_PER_GMMA_GROUP_N * BYTES_PER_ROW + }; + + // The number of STS per ROW. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // For Hopper, STS_PER_ROW should be 1 (at least for now.) + static_assert(STS_PER_ROW == 1, ""); + + // Ctor. + inline __device__ Smem_tile_hopper_gmma_tma_row_b(char* smem, char* smem_barrier) + : smem_(__nvvm_get_smem_pointer(smem)), + smem_barrier_(__nvvm_get_smem_pointer(smem_barrier)) {} + + // Move the write offset to next buffer. + // Not implemented since it is not needed at the moment. + inline __device__ void move_next_write_buffer() {} + + inline __device__ void move_next_write_buffer(int) {} + + // Move the read offset to next buffer. + // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + template + inline __device__ void store(cudaTmaDesc const* p_desc, int32_t const (&coord)[DIM], + uint16_t filter_offsets = 0, uint16_t mcast_cta_mask = 0) { + fmha::utmaldg(p_desc, smem_, smem_barrier_, coord); + } + + // The shared memory pointer. + uint32_t smem_; + // The barrier in smem. + uint32_t smem_barrier_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// A Row Major, A coming from SMEM +template +struct Smem_tile_hopper_a + : public Smem_tile_hopper_gmma_tma_row_a { + // The base class. + using Base = Smem_tile_hopper_gmma_tma_row_a; + + // Ctor. + inline __device__ Smem_tile_hopper_a(char* smem, char* smem_barrier) : Base(smem, smem_barrier) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Col Major, B coming from SMEM +template +struct Smem_tile_hopper_b + : public Smem_tile_hopper_gmma_tma_col_b { + // The base class. + using Base = Smem_tile_hopper_gmma_tma_col_b; + + // Ctor. + inline __device__ Smem_tile_hopper_b(char* smem, char* smem_barrier) : Base(smem, smem_barrier) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// B Row Major, B coming from SMEM +template +struct Smem_tile_hopper_b + : public Smem_tile_hopper_gmma_tma_row_b { + // The base class. + using Base = Smem_tile_hopper_gmma_tma_row_b; + + // Ctor. + inline __device__ Smem_tile_hopper_b(char* smem, char* smem_barrier) : Base(smem, smem_barrier) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace wip + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits_, + // The description of the tile computed by this CTA. + typename Cta_tile_, + // The layout of the tile. 
+ typename Layout_, + // The number of bytes per STS. + int BYTES_PER_STS_, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // Whether to use TMA. + bool USE_TMA, + // Whether A is coming for RF. + bool GMMA_A_RF = Traits_::GMMA_A_RF> +struct Smem_tile_hopper_a : public fmha::Smem_tile_without_skews< + Cta_tile_, Layout_::COL ? Cta_tile_::K : Cta_tile_::M, + Layout_::COL ? Cta_tile_::M : Cta_tile_::K, + Traits_::BITS_PER_ELEMENT_A, BYTES_PER_STS_, BUFFERS_PER_TILE_, 0, + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 8 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2), + 1, true, USE_TMA, 128 * 8 / Traits_::BITS_PER_ELEMENT_A> { + using Traits = Traits_; + using Cta_tile = Cta_tile_; + // The base class. + using Base = fmha::Smem_tile_without_skews< + Cta_tile, Layout_::COL ? Cta_tile::K : Cta_tile::M, Layout_::COL ? Cta_tile::M : Cta_tile::K, + Traits::BITS_PER_ELEMENT_A, BYTES_PER_STS_, BUFFERS_PER_TILE_, 0, + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 8 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2), + 1, true, USE_TMA, 128 * 8 / Traits::BITS_PER_ELEMENT_A>; + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The layout + using Layout = Layout_; + // The fragment. + using Fragment = fmha::Fragment_a; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_a; + + // the number of columns per one column of M_PER_GMMA_GROUP + enum { + NUM_COLS_PER_GMMA_GROUP_M = + (Mma_tile::M_PER_GMMA_GROUP * Base::BITS_PER_ELEMENT / 8 + Base::BYTES_PER_ROW - 1) / + Base::BYTES_PER_ROW + }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + static constexpr int GMMA_GROUP_SMEM_DISTANCE = + Layout::COL ? (Mma_tile::K_PER_GMMA_GROUP * NUM_COLS_PER_GMMA_GROUP_M * Base::BYTES_PER_ROW * + Cta_tile::WARP_GROUP_M) + : (Mma_tile::M_PER_GMMA_GROUP * Cta_tile::WARP_GROUP_M / + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 2 + : 4) * + Base::BYTES_PER_ROW); + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = Base::BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // Ctor. + inline __device__ Smem_tile_hopper_a(void* smem, int tidx) : Base(smem, tidx) {} + + // set the scale and bias smem pointer + inline __device__ void set_scale_bias_smem_ptr(char* scale_bias_smem_ptr, int tidx, int k) {} + + // Load from shared memory. + template + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {} + + // Move the read offset to next buffer. 
+ // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + // Overload set needs to be replicated for compatibility + inline __device__ void move_next_read_buffer(int N) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits_, + // The description of the tile computed by this CTA. + typename Cta_tile_, + // The layout of the tile. + typename Layout_, + // The number of bytes per STS. + int BYTES_PER_STS_, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // USe TMA or not, + bool USE_TMA> +struct Smem_tile_hopper_b + : public fmha::Smem_tile_without_skews< + Cta_tile_, + Layout_::COL ? Cta_tile_::N : Cta_tile_::K, // ROWS + Layout_::COL ? Cta_tile_::K : Cta_tile_::N, // COLS + Traits_::BITS_PER_ELEMENT_B, BYTES_PER_STS_, BUFFERS_PER_TILE_, + 0, // LDS_FAST_PATH + // Determine ROWS_PER_XOR_PATTERN from the swizzle mode: + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 8 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : /* 32B or NONE */ 2), + 1, // COLS_PER_XOR_PATTERN + true, // USE_PREDICATES + USE_TMA, + 128 * 8 / Traits_::BITS_PER_ELEMENT_B // LEAD_DIM_ELEMENTS + > { + using Traits = Traits_; + using Cta_tile = Cta_tile_; + // The base class. + using Base = fmha::Smem_tile_without_skews< + Cta_tile, Layout_::COL ? Cta_tile::N : Cta_tile::K, Layout_::COL ? Cta_tile::K : Cta_tile::N, + Traits::BITS_PER_ELEMENT_B, BYTES_PER_STS_, BUFFERS_PER_TILE_, 0, + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 8 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 4 + : 2), + 1, true, USE_TMA, 128 * 8 / Traits::BITS_PER_ELEMENT_B>; + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The layout + using Layout = Layout_; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The number of desc within a gmma group (kblock limited) + static constexpr fmha::Gmma_descriptor_size GMMA_DESC_SIZE_PER_GROUP = + fmha::Gmma_descriptor_size::ONE; + // The SWIZZLE_128B descriptor + using Gmma_descriptor = + fmha::Gmma_descriptor_b; + + // the number of rows per one row of N_PER_GMMA_GROUP + enum { + NUM_ROWS_PER_GMMA_GROUP_N = + (Mma_tile::N_PER_GMMA_GROUP * Base::BITS_PER_ELEMENT / 8 + Base::BYTES_PER_ROW - 1) / + Base::BYTES_PER_ROW + }; + + // The distance in byte between different GMMA groups (might need multiple due to cta tile size) + // each GMMA group is of size GMMA_M x GMMA_N x Kblock + + // The dimension that we split. + // Add buffers when we have multiple buffers for split head dimensions. + // Split-d smem view (2 split D, and 3 buffers): d0, d0, d0, d1, d1, d1. + static constexpr int GMMA_GROUP_SPLIT_DIM = + Layout::COL ? Mma_tile::N_PER_GMMA_GROUP : (Mma_tile::K_PER_GMMA_GROUP * BUFFERS_PER_TILE_); + + // The split factor. + static constexpr int GMMA_GROUP_SPLIT_FACTOR = + (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_128B ? 1 + : desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B ? 2 + : 4); + + // Make sure the dimension that we split is a multiple of the split factor. + static_assert(GMMA_GROUP_SPLIT_DIM % GMMA_GROUP_SPLIT_FACTOR == 0); + + // The distance between two "groups" in shared memory. 
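+ // (For example, with SWIZZLE_128B the split factor is 1, so the distance is simply
+ //  GMMA_GROUP_SPLIT_DIM rows of Base::BYTES_PER_ROW bytes each; SWIZZLE_64B and 32B
+ //  halve and quarter that, matching the split factors above.)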
+ static constexpr int GMMA_GROUP_SMEM_DISTANCE = + GMMA_GROUP_SPLIT_DIM / GMMA_GROUP_SPLIT_FACTOR * Base::BYTES_PER_ROW; + + // the size of one buffer in bytes in shared memory, without the 4 LSB. + // this is needed to increment the GMMA desc to the next buffer + enum { BYTES_PER_BUFFER_NO_4LSB = Base::BYTES_PER_BUFFER / 16 }; + + // this is needed to decrement GMMA desc + enum { + BYTES_PER_BUFFER_INC_BOUNDARY_NO_4LSB = + BYTES_PER_BUFFER_NO_4LSB * BUFFERS_PER_TILE_ - BYTES_PER_BUFFER_NO_4LSB + }; + + // Ctor. + inline __device__ Smem_tile_hopper_b(void* smem, int tidx) : Base(smem, tidx) { + warp_id_ = tidx / 32; + lane_id_ = tidx % 32; + + // each pair of warps transposes 8x8 in place + // each warp responsible for diagonal 4x4s + // calculate index in 8x8 block + block_row_ = lane_id_ / 4; + block_col_ = (lane_id_ % 4) + ((warp_id_ % 2) ^ (block_row_ / 4)) * 4; + + // diagonal 4x4s will 2x conflict for SWIZZLE_32B + // 1 warp per 8x8, 2 4x8 load+store + if (Traits::GMMA_N == 8) { + block_row_ = lane_id_ / 8; + block_col_ = lane_id_ % 8; + } + + // offset when all 4 warps participate in transpose + block_col_offset_ = (warp_id_ / 2) * 8; + } + + int warp_id_, lane_id_; + int block_row_, block_col_, block_col_offset_; + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {} + + // Load from smem, do something (e.g. transpose), then store back to smem + inline __device__ void load_and_store(int ki) { + /* + using B_type = typename Traits::B_type; + + // TODO: move these to B_RF smem tiles + + // 8 channel per group fp16 fprop/dgrad with 64x16x16 gmma + // move 8x8 OOB zeros to right diagonal, 8x8 in-bounds weights on left diagonal + if (Cta_tile::N_PER_GROUP == 8 && Traits::GMMA_N == 16 + && Traits::BITS_PER_ELEMENT_B == 16) { + // just need to swap 2 cores within a single SWIZZLE_32B, one of which is just zero + // 1 LDSM.M88.1 + if (warp_id_ == 0) { + int smem_row_offset = ki * 4 * 128 + 2 * 128; // 4 rows per 16x16, swap the bottom 8x16 + int lds_block_idx = lane_id_ * 2; // ldsm.m88.1 only uses first 8 threads for address + int lds_smem_idx = lds_block_idx ^ (lane_id_ / 4); + + uint32_t data; + uint32_t lds_smem_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + lds_smem_idx * 16; + fmha::ldsm(data, lds_smem_ptr); + + __syncwarp(); + + // move values to adjacent core + fmha::stsm(lds_smem_ptr ^ 16, data); + + // set zeros at previous core + fmha::stsm(lds_smem_ptr, static_cast(0)); + } + } + + // 4 channel per group tf32 fprop with 64x8x8 gmma + // move 4x4 in-bounds weights on left diagonal, OOB zeros everywhere else + if (Cta_tile::N_PER_GROUP == 4 && Traits::GMMA_N == 8 + && Layout::COL && Traits::BITS_PER_ELEMENT_B == 32) { + // just need to swap the bottom 4x8, 1 elt per thread for 1 warp + // 1 lds/sts.32 per thread + if (warp_id_ == 0) { + int smem_row_offset = ki * Base::ROWS_PER_XOR_PATTERN * 128 + 128; + int lds_smem_idx = lane_id_; + uint32_t lds_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + lds_smem_idx * sizeof(B_type); + uint32_t data; + lds(data, lds_ptr); + + __syncwarp(); + + sts(lds_ptr ^ 16, data); + } + } + + // partial transpose of 8xN_PER_GROUP operand for tf32 grouped dgrad + // todo: revise this for tf32 grouped wgrad, move to partial specialization + static constexpr bool IS_TF32_GROUPED_DGRAD = + (Cta_tile::GROUPS_N > 1 && Cta_tile::GROUPS_K > 1 || Cta_tile::N_PER_GROUP == 32) + && Layout::ROW && Traits::BITS_PER_ELEMENT_B == 32; + if (IS_TF32_GROUPED_DGRAD) { + static 
constexpr int XOR_SCALE = 16 / sizeof(B_type); // 16B swizzle over 4B elements + static constexpr int ROWS_PER_128B = kDivUp( 128, Traits::GMMA_N * sizeof(B_type) ); + + if (Traits::GMMA_N == 8) { + if (warp_id_ == 0) { + + int smem_row_offset = ki * Base::ROWS_PER_XOR_PATTERN * 128; + uint32_t data[2]; + + #pragma unroll + for (int ii = 0; ii < 2; ii++) { + // get index in row-major 8x8 + int lds_block_row = block_row_ + ii * 4; + int lds_block_col = block_col_; + int lds_block_idx = lds_block_row * 8 + lds_block_col; + + // swizzle + int lds_xor_factor = (lds_block_row / ROWS_PER_128B) * XOR_SCALE; + int lds_smem_idx = lds_block_idx ^ lds_xor_factor; + + // Load from smem + uint32_t lds_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + lds_smem_idx * sizeof(B_type); + lds(data[ii], lds_ptr); + } + + __syncwarp(); + + #pragma unroll + for (int ii = 0; ii < 2; ii++) { + // get index in col-major 8x8 + int sts_block_row = block_col_; + int sts_block_col = block_row_ + ii * 4; + if (Cta_tile::N_PER_GROUP == 4 && ii == 1) { + // place 4x4 weights on diagonal for 4-channel tf32 group dgrad + sts_block_row ^= 4; + } + int sts_block_idx = sts_block_row * 8 + sts_block_col; + + // swizzle + int sts_xor_factor = (sts_block_row / ROWS_PER_128B) * XOR_SCALE; + int sts_smem_idx = sts_block_idx ^ sts_xor_factor; + + // store to smem + uint32_t sts_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + sts_smem_idx * sizeof(B_type); + sts(sts_ptr, data[ii]); + } + + } // warp_id == 0 + } else { + // loop over 8x16 blocks + #pragma unroll + for (int ii = 0; ii < kDivUp(Cta_tile::N_PER_GROUP, 16); ii++) { + int smem_row_offset = ki * Base::ROWS_PER_XOR_PATTERN * 128; + + // get index in row-major 8xN_PER_GROUP + int lds_block_row = block_row_; + int lds_block_col = block_col_ + block_col_offset_ + ii * 16; + int lds_block_idx = lds_block_row * Cta_tile::N_PER_GROUP + + lds_block_col; + + // swizzle + int lds_xor_factor = (lds_block_row / ROWS_PER_128B) * XOR_SCALE; + int lds_smem_idx = lds_block_idx ^ lds_xor_factor; + + // Load from smem + uint32_t lds_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + lds_smem_idx * sizeof(B_type); + uint32_t data; + lds(data, lds_ptr); + + __syncwarp(); + + // get index in row-major 8xN_PER_GROUP with 8x8 in-place transposes + int sts_block_row = block_col_; + int sts_block_col = block_row_ + block_col_offset_ + ii * 16; + int sts_block_idx = sts_block_row * Cta_tile::N_PER_GROUP + + sts_block_col; + + // swizzle + int sts_xor_factor = (sts_block_row / ROWS_PER_128B) * XOR_SCALE; + int sts_smem_idx = sts_block_idx ^ sts_xor_factor; + + // store to smem + uint32_t sts_ptr = this->smem_ + this->smem_read_buffer_ + + smem_row_offset + + sts_smem_idx * sizeof(B_type); + sts(sts_ptr, data); + } + } + } + + // make sure sts are visible to gmma + fence_view_async_shared(); + */ + } + + // Move the read offset to next buffer. + inline __device__ void move_next_read_buffer() {} + + // Move the read offset to next buffer. + inline __device__ void move_next_read_buffer(int buffer_id) { + this->smem_read_buffer_ = buffer_id * Base::BYTES_PER_BUFFER; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < // GMMA instruction shape in M dim + int GMMA_M, + // GMMA instruction shape in N dim + int GMMA_N, + // GMMA instruction shape in K dim + int GMMA_K, + // GMMA A operand coming from RF? + bool GMMA_A_RF, + // GMMA B operand coming from RF? 
+ bool GMMA_B_RF, + // The description of the tile computed by this CTA. + typename Cta_tile, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // Use TMA or not, + bool USE_TMA, int BUFFERS_PER_TILE> +struct Smem_tile_v, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, desc_mode, USE_TMA> { + static constexpr bool TRANSPOSE = false; + + using Cta_tile_gmma = Cta_tile; + + using Base = fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_fp16_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, desc_mode, USE_TMA>; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void transpose_tile(int) { + // Transpose is fused into HGMMA. + } + + inline __device__ void transpose_tile(int, uint32_t, uint32_t) { + // Transpose is fused into HGMMA. + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < // GMMA instruction shape in M dim + int GMMA_M, + // GMMA instruction shape in N dim + int GMMA_N, + // GMMA instruction shape in K dim + int GMMA_K, + // GMMA A operand coming from RF? + bool GMMA_A_RF, + // GMMA B operand coming from RF? + bool GMMA_B_RF, + // The description of the tile computed by this CTA. + typename Cta_tile, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // Use TMA or not, + bool USE_TMA, int BUFFERS_PER_TILE> +struct Smem_tile_v, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, // BUFFERS_PER_TILE, + desc_mode, USE_TMA> { + static constexpr bool TRANSPOSE = false; + + using Cta_tile_gmma = Cta_tile; + + using Base = fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_fp32_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, // BUFFERS_PER_TILE, + desc_mode, USE_TMA>; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void transpose_tile(int) { + // Transpose is fused into HGMMA. + } + + inline __device__ void transpose_tile(int, uint32_t, uint32_t) { + // Transpose is fused into HGMMA. + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < // GMMA instruction shape in M dim + int GMMA_M, + // GMMA instruction shape in N dim + int GMMA_N, + // GMMA instruction shape in K dim + int GMMA_K, + // GMMA A operand coming from RF? + bool GMMA_A_RF, + // GMMA B operand coming from RF? + bool GMMA_B_RF, + // The description of the tile computed by this CTA. 
+ typename Cta_tile, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // Use TMA or not, + bool USE_TMA, int BUFFERS_PER_TILE> +struct Smem_tile_v, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, // BUFFERS_PER_TILE, + desc_mode, USE_TMA> { + static constexpr bool TRANSPOSE = false; + + using Cta_tile_gmma = Cta_tile; + + using Base = fmha::Smem_tile_hopper_b< + fmha::Hopper_hgmma_bf16_traits, Cta_tile, + fmha::Row, + 16, // BYTES_PER_STS + BUFFERS_PER_TILE, // BUFFERS_PER_TILE, + desc_mode, USE_TMA>; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void transpose_tile(int) { + // Transpose is fused into HGMMA. + } + + inline __device__ void transpose_tile(int, uint32_t, uint32_t) { + // Transpose is fused into HGMMA. + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Transposer {}; + +template +struct Transposer { + static_assert(Cta_tile::K % 128 == 0); + + enum { + WARPS_M = Cta_tile::WARPS_M, + WARPS_N = Cta_tile::WARPS_N, + WARPS_K = Cta_tile::WARPS_K, + }; + + enum { + WARPS_4x1x1 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1), + WARPS_4x1x2 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2), + }; + + enum { BYTES_PER_LDS = 16 }; + + enum { BYTES_PER_ROW = 128 }; + + // D=64 and 4 warps. + // Per warp we load 32 rows x 16 columns with LDSM.Tx4, 128 rows per CTA. + enum { S = Cta_tile::K >= 128 ? 128 : Cta_tile::K }; // The sequence length. + + enum { D = Cta_tile::N >= 128 ? 128 : Cta_tile::N }; // The head dimension. + + // static_assert(S % 128 == 0); + static_assert(WARPS_4x1x1 || WARPS_4x1x2); + static_assert(D % (BYTES_PER_LDS * WARPS_K) == 0); + + enum { ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING = 128 }; // LDSMx4 + + enum { ROW_PACKING = BYTES_PER_ROW / (D * sizeof(typename Traits::B_type)) }; + + enum { ROWS_PER_LDSM_PER_CTA = ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING / ROW_PACKING }; + + enum { ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b::VALUE }; + + static_assert(ROWS_PER_XOR_PATTERN == 8); + + // The number of loads in K dimension. + enum { K = S / ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING }; + + // static_assert(K * ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING == S); + // static_assert(K == 3); + // The number of loads in the D dimension. + enum { N = D / (BYTES_PER_LDS * WARPS_K) }; // 16 bytes per load + + static_assert(N * BYTES_PER_LDS * WARPS_K == D); + + uint4 regs_[UNROLL_N][K]; + + uint32_t read_offset_; + uint32_t write_offset_; + uint32_t smem_read_loc_; + uint32_t smem_write_loc_; + + inline __device__ Transposer(int tidx) { + int read_row, read_col; + + if (WARPS_4x1x1 && N == 8) { // D=128, 1 warp in N + read_row = (tidx & 0x7f); + read_col = (tidx & 0x07); + } else if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N + read_row = (tidx & 0xe0) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + } else if (WARPS_4x1x1 && N == 2) { // D=32, 1 warp in N + read_row = (tidx & 0x60) / 4 + (tidx & 0x1c) / 4; + read_col = (tidx & 0x03) * 2; + read_col ^= (read_row & 0x01); + } else if (WARPS_4x1x2 && N == 4) { // D=128, 2 warps in N + read_row = (tidx & 0x7f); + read_col = (tidx & 0x07); + // For two warpgroups we do two steps in N at once. 
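+ // ((tidx & 0x80) / 128 is 1 for the second warp group, so XOR-ing it into read_col
+ //  points the two warp groups at adjacent 16B columns of the same rows.)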
+ read_col ^= (tidx & 0x80) / 128; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N + read_row = (tidx & 0x60) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + // For two warpgroups we do two steps in N at once. + read_col ^= (tidx & 0x80) / 128; + } else if (WARPS_4x1x2 && N == 1) { // D=32, 2 warps in N + read_row = (tidx & 0x60) / 4 + (tidx & 0x1c) / 4; + read_col = (tidx & 0x03) * 2; + read_col ^= (read_row & 0x01); + // For two warpgroups we do two steps in N at once. + read_col ^= (tidx & 0x80) / 128; + } else { + assert(false); + } + + read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + int write_row, write_col; + if (WARPS_4x1x1) { // swizzle_128byte + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; + } else if (WARPS_4x1x2) { + // Same as above, with second warp group writing next 16 rows. + write_row = (tidx & 0x80) / 8 + (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; + } else { + assert(false); + } + + write_col ^= (write_row & 0x07); + + write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_LDS; + } + + inline __device__ void transpose(int tidx, uint32_t smem) { transpose_(tidx, smem, smem); } + + template + inline __device__ void transpose_(uint32_t smem_src, uint32_t smem_dst) { +#pragma unroll + for (int n_begin = 0; n_begin < N; n_begin += UNROLL_N) { + transpose_ldmatrix(n_begin, smem_src); + transpose_stmatrix(n_begin, smem_dst); + } + } + + inline __device__ void transpose_ldmatrix(int n_begin, uint32_t smem_src) { + static_assert(N % UNROLL_N == 0, ""); + + uint4 tmp[UNROLL_N][K]; + if (n_begin == 0) { + smem_read_loc_ = smem_src + read_offset_; + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { // 2 + fmha::ldsmt(tmp[nii][ki], smem_read_loc_ + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW); + } + + if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N + smem_read_loc_ ^= (ni % 2 == 0 ? 1 : 3) * 16; + } else if (WARPS_4x1x1 && N == 2) { // D=32, 1 warp in N + smem_read_loc_ ^= 16; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N + smem_read_loc_ ^= 32; + } else if (WARPS_4x1x2 && N == 4) { // D=128, 2 warps in N + smem_read_loc_ ^= (ni % 2 == 0 ? 1 : 3) * 32; + } else if (WARPS_4x1x1 && N == 8) { // D=128, 1 warp in N + smem_read_loc_ ^= ((ni % 4 == 3) ? 7 : (ni % 2 == 1 ? 3 : 1)) * 16; + } else if (N != 1) { + assert(false); + } + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::swizzle_rows(regs_[nii][ki].x, regs_[nii][ki].z, tmp[nii][ki].x, + tmp[nii][ki].y); // PRMT 0+1 + fmha::swizzle_rows(regs_[nii][ki].y, regs_[nii][ki].w, tmp[nii][ki].z, + tmp[nii][ki].w); // PRMT 2+3 + } + } + } + + template + inline __device__ void transpose_stmatrix(int n_begin, uint32_t smem_dst) { + // After LDSM.Tx4 registers hold 2x2 elts: + // [00, 01] + // [10, 11] + // With row offsets + // x: + 0 + // y: + 8 + // z: +16 (g) + // w: +24 (o) + // + // After PRMT 0, the : + // [00, 01] [80, 81] => x: [00, 10, 80, 90], i.e. col 0 + // [10, 11] [90, 91] => z: [01, 11, 81, 91], i.e. col 1 + // + // [g0, g1] [o0, o1] => y: [g0, h0, o0, p0], i.e. col 0 + // [h0, h1] [p0, p1] => w: [g1, h1, o1, p1], i.e. 
col 1 + // + // Therefore, when looking at the transpose, quad q holds cols 2 * q + [0, 1], i.e. + // - quad 0 holds cols 0, 1 + // - quad 1 holds cols 2, 3 + // - etc. + // + // This fits with the accumulator layout, since N strides in steps of 8 per thread. + + if (SYNC) { // needed if src and dst are the same. + __syncthreads(); // LDSM.T done. We should now have a D x S tile in registers. SMEM can be + // written. + } + + if (n_begin == 0) { + smem_write_loc_ = smem_dst + write_offset_; + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::stsm(smem_write_loc_ + ki * BYTES_PER_ROW * D, regs_[nii][ki]); + } + if (WARPS_4x1x1) { // D=64, 1 warp in N. + smem_write_loc_ += 16 * BYTES_PER_ROW; + } else if (WARPS_4x1x2) { // D=64, 2 warps in N. + smem_write_loc_ += 32 * BYTES_PER_ROW; + } else { + assert(false); + } + } + } +}; + +template +struct Transposer { + static_assert(Cta_tile::K % 64 == 0); + + enum { + WARPS_M = Cta_tile::WARPS_M, + WARPS_N = Cta_tile::WARPS_N, + WARPS_K = Cta_tile::WARPS_K, + }; + + enum { + WARPS_4x1x1 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1), + WARPS_4x1x2 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2), + }; + + enum { BYTES_PER_LDS = 16 }; + + // D=64 and 4 warps. + // Per warp we load 32 rows x 16 columns with LDSM.Tx4, 128 rows per CTA. + enum { S = Cta_tile::K >= 128 ? 128 : Cta_tile::K }; // The sequence length. + + enum { D = Cta_tile::N >= 128 ? 128 : Cta_tile::N }; // The head dimension. + + static_assert(S % 64 == 0); + static_assert(WARPS_4x1x1); + static_assert(D % 32 == 0); + + static_assert(S == 64 && D == 128); + + // Two warps in S dim. + enum { ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING = 64 }; // LDSMx4 + + enum { BYTES_PER_ROW = 128 }; + + enum { ROW_PACKING = Div_up::VALUE }; + + enum { + ROWS_PER_LDSM_PER_CTA = ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING / ROW_PACKING + }; // due to row_packing + + // The number of loads in K dimension. + enum { K = S / ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING }; + + // The number of loads in the D dimension. Use two warps in D dim. 
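+ // (With the D == 128 case asserted above this gives N = 4: each step covers 32 bytes of a
+ //  row, i.e. one 16B LDSM.T column from each of the two warps working along D.)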
+ enum { N = D / 32 }; + + uint4 regs_[UNROLL_N][K]; + + uint32_t read_offset_; + uint32_t write_offset_; + uint32_t smem_read_loc_; + uint32_t smem_write_loc_; + + inline __device__ Transposer(int tidx) { + int read_row, read_col; + + if (WARPS_4x1x1 && N == 1) { // D=32, 2 warps in N + read_row = (tidx & 0x20) / 4 + (tidx & 0x1c) / 4; + read_col = (tidx & 0x03) * 2; + read_col ^= (read_row & 0x01); + read_col ^= ((tidx & 0x40) / 64); + } else if (WARPS_4x1x1 && N == 2) { // D=64, 2 warps in N + read_row = (tidx & 0x20) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + read_col ^= ((tidx & 0x40) / 64); + } else if (WARPS_4x1x1 && N == 4) { // D=128, 2 warps in N + read_row = (tidx & 0x3f); + read_col = (tidx & 0x07); + read_col ^= ((tidx & 0x40) / 64); + } else { + assert(false); + } + + read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // static_assert(ROWS_PER_LDSM_PER_CTA == 32); + // constexpr int ROWS_PER_XOR_PATTERN = 4; + // constexpr int ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b::VALUE; + + int row, col; + if (WARPS_4x1x1) { + row = (tidx & 0x40) / 4 + (tidx & 0x10) / 2 + (tidx & 0x07); + col = (tidx & 0x20) / 16 + (tidx & 0x08) / 8; + col = col + (row % 2) * 4; + row = row / 2; + col = col ^ (row % 4); + } else { + assert(false); + } + write_offset_ = row * BYTES_PER_ROW + col * BYTES_PER_LDS; + }; + + inline __device__ void transpose(int tidx, uint32_t smem) { transpose_(tidx, smem, smem); } + + template + inline __device__ void transpose_(uint32_t smem_src, uint32_t smem_dst) { +#pragma unroll + for (int n_begin = 0; n_begin < N; n_begin += UNROLL_N) { + transpose_ldmatrix(n_begin, smem_src); + transpose_stmatrix(n_begin, smem_dst); + } + } + + inline __device__ void transpose_ldmatrix(int n_begin, uint32_t smem_src) { + static_assert(N % UNROLL_N == 0, ""); + + uint4 tmp[UNROLL_N][K]; + if (n_begin == 0) { + smem_read_loc_ = smem_src + read_offset_; + } +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { +#pragma unroll + for (int ki = 0; ki < K; ki++) { + int const nii = ni - n_begin; + fmha::ldsmt(tmp[ni][ki], smem_read_loc_ + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW); + } + + if (WARPS_4x1x1 && N == 2) { // D=64, 2 warps in N + smem_read_loc_ ^= 32; + } else if (WARPS_4x1x1 && N == 4) { // D=128, 2 warps in N + smem_read_loc_ ^= (ni % 2 == 1 ? 3 * 32 : 32); + } else if (N != 1) { + assert(false); + } + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::swizzle_rows(regs_[nii][ki].x, regs_[nii][ki].z, tmp[nii][ki].x, + tmp[nii][ki].y); // PRMT 0+1 + fmha::swizzle_rows(regs_[nii][ki].y, regs_[nii][ki].w, tmp[nii][ki].z, + tmp[nii][ki].w); // PRMT 2+3 + } + } + } + + template + inline __device__ void transpose_stmatrix(int n_begin, uint32_t smem_dst) { + // After LDSM.Tx4 registers hold 2x2 elts: + // [00, 01] + // [10, 11] + // With row offsets + // x: + 0 + // y: + 8 + // z: +16 (g) + // w: +24 (o) + // + // After PRMT 0, the : + // [00, 01] [80, 81] => x: [00, 10, 80, 90], i.e. col 0 + // [10, 11] [90, 91] => z: [01, 11, 81, 91], i.e. col 1 + // + // [g0, g1] [o0, o1] => y: [g0, h0, o0, p0], i.e. col 0 + // [h0, h1] [p0, p1] => w: [g1, h1, o1, p1], i.e. col 1 + // + // Therefore, when looking at the transpose, quad q holds cols 2 * q + [0, 1], i.e. + // - quad 0 holds cols 0, 1 + // - quad 1 holds cols 2, 3 + // - etc. 
+ // + // This fits with the accumulator layout, since N strides in steps of 8 per thread. + + if (SYNC) { + __syncthreads(); // LDSM.T done. We should now have a D x S tile in registers. SMEM can be + // written. + } + + if (n_begin == 0) { + smem_write_loc_ = smem_dst + write_offset_; + } + +#pragma unroll + for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { + int const nii = ni - n_begin; +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::stsm(smem_write_loc_ + ki * BYTES_PER_ROW * D / 2, regs_[nii][ki]); + } + if (WARPS_4x1x1) { // D=64, 1 warp in N. + smem_write_loc_ += 16 * BYTES_PER_ROW; + } else { + assert(false); + } + } + } +}; + +template < + // The instruction traits. + typename Traits, + // The Cta_tile. + typename Cta_tile, + // The number of buffers. + int BUFFERS_PER_TILE, + // GMMA descriptor mode + fmha::Gmma_descriptor_mode desc_mode, + // USe TMA or not, + bool USE_TMA> +struct Smem_tile_v_gmma { + static_assert(sizeof(typename Traits::B_type) == 1); + + // K is the sequence length dimension (128 for GMMA) + enum { K_ = Cta_tile::K % 128 == 0 ? 128 : 64 }; + + static_assert(Cta_tile::K % K_ == 0); + + // static_assert(Cta_tile::N == 128); + // static_assert(K_ == 128); + // static_assert(BUFFERS_PER_TILE == 2); + + using Cta_tile_gmma_ = + typename Traits::template Cta_tile; + + // TODO Swizzle_32B? + static constexpr fmha::Gmma_descriptor_mode GMMA_DESC_MODE_V = + Cta_tile_gmma_::K * sizeof(typename Traits::B_type) >= 128 + ? fmha::Gmma_descriptor_mode::SWIZZLE_128B + : fmha::Gmma_descriptor_mode::SWIZZLE_64B; + + static_assert( + (Cta_tile::K % 128 == 0 && GMMA_DESC_MODE_V == fmha::Gmma_descriptor_mode::SWIZZLE_128B) || + (Cta_tile::K % 64 == 0 && GMMA_DESC_MODE_V == fmha::Gmma_descriptor_mode::SWIZZLE_64B)); + + enum { NUM_KGROUPS = Cta_tile::K / Cta_tile_gmma_::K }; + + static_assert(NUM_KGROUPS * Cta_tile_gmma_::K == Cta_tile::K); + + enum { BYTES_PER_STS = 16 }; + + // The compute tile only requires static information from Smem_tile_v and accesses SMEM directly + // through GMMA. Hence, we declare a SxD column major matrix in SMEM and have to make sure at + // runtime that the data is transposed. Note that for K > 128, we are using two buffers per tile, + // which we have to fill accordingly. + using Base_ = fmha::Smem_tile_hopper_b; + + // Split D or not, which influences the GMMA_GROUP_SMEM_DISTANCE, and BYTES_PER_BUFFER_NO_4LSB. + // Split-d smem view (2 split D, and 3 buffers): d0, d0, d0, d1, d1, d1. + // The group distance would be number_of_buffers * buffer_size. + // The buffer size is the size for split-d. + static constexpr size_t GMMA_GROUP_SMEM_DISTANCE = + Base_::GMMA_GROUP_SMEM_DISTANCE * BUFFERS_PER_TILE; + static constexpr size_t BYTES_PER_BUFFER_NO_4LSB = Base_::BYTES_PER_BUFFER_NO_4LSB; + + using Gmma_descriptor = typename Base_::Gmma_descriptor; + + struct Base : public Base_ { + using Transposer = Transposer; + static_assert(USE_TMA == false); + static constexpr bool TRANSPOSE = true; + + enum { NUM_KGROUPS = Cta_tile::K / Cta_tile_gmma_::K }; + + enum { ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b::VALUE }; + + using Descriptor = typename Base_::Gmma_descriptor; + + // Delegate all the stores to the Row-Major Smem_tile. 
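+ // (The delegate below is only used to compute STS/LDGSTS store addresses when filling the
+ //  tile; GMMA itself reads through the descriptor, after the transposer has rewritten the
+ //  data in place.)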
+ using Store_delegate = Smem_tile_without_skews; + + using Store_type = typename Store_delegate::Store_type; + + enum { S = Cta_tile::K }; + + // static_assert(Descriptor::BYTES_PER_LEADING_DIM == 128); + // static_assert(Descriptor::STRIDE_BYTE_OFFSET == K_ * 8 / 16); // 128 * 8 / 16 + // static_assert(Descriptor::TRANS_MODE == fmha::Gmma_descriptor_transpose::NOTRANS); + // static_assert(Base::BYTES_PER_TILE == S * 64); + // static_assert(!Descriptor::LEADING_BYTE_OFFSET_NEEDED); + // static_assert(Descriptor::LEADING_BYTE_OFFSET == 128 * 64 / 16); + // static_assert(Descriptor::BYTES_PER_DESC_NO_4LSB == 32 * 1 / 16); + // static_assert(Descriptor::BYTES_DESC_INC_BOUNDARY_NO_4LSB == (K_ / 32 - 1) * 2); + // static_assert(Base::BYTES_PER_BUFFER_NO_4LSB == K_ * 64 / 16); + // static_assert(Base::GMMA_GROUP_SMEM_DISTANCE == 128 * 128 * 2); + // static_assert(Base::BYTES_PER_BUFFER_NO_4LSB == 128 * 128); + + // static_assert(Store_delegate::N_WITH_PADDING == 64); + // static_assert(Store_delegate::ROWS_PER_XOR_PATTERN == 4); + // static_assert(Store_delegate::BYTES_PER_ROW_BEFORE_PACKING == 64); + // static_assert(Store_delegate::ROWS == S / 2); + // static_assert(Store_delegate::BYTES_PER_ROW == 128); + + // Number of rows a warp loads per LDSMx4 + enum { ROWS_PER_LDSM = 4 * 8 }; + + enum { ROWS_PER_LDSM_PER_CTA = ROWS_PER_LDSM * Cta_tile::WARPS_M }; + + static_assert(Cta_tile::WARPS_M == 4); + + enum { LDSMS = Cta_tile::K / ROWS_PER_LDSM_PER_CTA }; + + // TODO we're assigning all rows loaded by a warp group (128 per CTA) to the K dimension. + // This only works for K a multiple of 128. + // For S=192, we want 3 blocks of 64xD. + // static_assert(LDSMS * ROWS_PER_LDSM_PER_CTA == Cta_tile::K); + + static_assert(LDSMS == S / 128); + + enum { BYTES_PER_LDS = 16 }; + + enum { BYTES_PER_ROW = Store_delegate::BYTES_PER_ROW }; + + enum { + WARPS_M = Cta_tile::WARPS_M, + WARPS_N = Cta_tile::WARPS_N, + WARPS_K = Cta_tile::WARPS_K, + }; + + enum { + WARPS_4x1x1 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1), + WARPS_4x1x2 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2), + }; + + inline __device__ Base(void* smem, int tidx) + : Base_(smem, tidx), delegate(smem, tidx), transposer(tidx) {} + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N]) { + uint32_t smem_ptrs[N]; + delegate.compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N], uint32_t (&preds)[M]) { + uint32_t smem_ptrs[N]; + delegate.compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N], uint32_t preds) { + delegate.store(data, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M]) { + uint32_t smem_ptrs[N]; + delegate.compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { + uint32_t tmp[1] = {preds}; + delegate.store(gmem_ptrs, tmp); + } + + // Store to the tile in shared memory. 
+ template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds) { + uint32_t tmp[1] = {preds}; + delegate.store(gmem_ptrs, tmp); + } + + // Initial offset (via tidx) has been moved to ctor + inline __device__ void transpose_tile(int /* tidx */) { transposer.transpose(0, this->smem_); } + + template + inline __device__ void transpose_tile(uint32_t smem_src, uint32_t smem_dst) { + transposer.template transpose_(smem_src, smem_dst); + } + + inline __device__ void transpose_tile_ldmatrix(int, uint32_t smem) { + transposer.transpose_ldmatrix(0, smem); + } + + inline __device__ void transpose_tile_stmatrix(int, uint32_t smem) { + transposer.template transpose_stmatrix(0, smem); + } + + inline __device__ void transpose_tile_128(int tidx) { + // D=64 and 4 warps. + // Per warp we load 32 rows x 16 columns with LDSM.Tx4, 128 rows per CTA. + constexpr int S = Cta_tile::K; // The sequence length. + constexpr int D = Cta_tile::N; // The head dimension. + // static_assert(S == 256); + static_assert(D == 64); + // static_assert(S % 128 == 0); + static_assert(WARPS_4x1x1 || WARPS_4x1x2); + static_assert(D % (16 * WARPS_K) == 0); + + constexpr int ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING = 128; // LDSMx4 + constexpr int BYTES_PER_ROW = 128; + constexpr int ROW_PACKING = BYTES_PER_ROW / (D * sizeof(Traits::B_type)); + + // The number of loads in K dimension. + constexpr int K = S / ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING; + // static_assert(K * ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING == S); + // static_assert(K == 3); + // The number of loads in the D dimension. + constexpr int N = D / (16 * WARPS_K); + static_assert(N * 16 * WARPS_K == D); + + int read_row, read_col; + + if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N + read_row = (tidx & 0xe0) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N + read_row = (tidx & 0x60) / 2 + (tidx & 0x1e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + // For two warpgroups we do two steps in N at once. + read_col ^= (tidx & 0x80) / 128; + } else { + assert(false); + } + + uint32_t offset = read_row * BYTES_PER_ROW + read_col * 16; + + constexpr int ROWS_PER_LDSM_PER_CTA = + ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING / ROW_PACKING; // due to row_packing + + uint4 tmp[N][K]; + uint32_t smem_tmp = this->smem_; //__nvvm_get_smem_pointer(v_smem_) ; + uint32_t smem_loc = smem_tmp + offset; + +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::ldsmt(tmp[ni][ki], smem_loc + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW); + } + + if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N + smem_loc ^= (ni % 2 == 0 ? 1 : 3) * 16; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N + smem_loc ^= 32; + } else { + assert(false); + } + } + + uint4 regs[N][K]; + +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::swizzle_rows(regs[ni][ki].x, regs[ni][ki].z, tmp[ni][ki].x, + tmp[ni][ki].y); // PRMT 0+1 + fmha::swizzle_rows(regs[ni][ki].y, regs[ni][ki].w, tmp[ni][ki].z, + tmp[ni][ki].w); // PRMT 2+3 + } + } + + // After LDSM.Tx4 registers hold 2x2 elts: + // [00, 01] + // [10, 11] + // With row offsets + // x: + 0 + // y: + 8 + // z: +16 (g) + // w: +24 (o) + // + // After PRMT 0, the : + // [00, 01] [80, 81] => x: [00, 10, 80, 90], i.e. col 0 + // [10, 11] [90, 91] => z: [01, 11, 81, 91], i.e. col 1 + // + // [g0, g1] [o0, o1] => y: [g0, h0, o0, p0], i.e. 
col 0 + // [h0, h1] [p0, p1] => w: [g1, h1, o1, p1], i.e. col 1 + // + // Therefore, when looking at the transpose, quad q holds cols 2 * q + [0, 1], i.e. + // - quad 0 holds cols 0, 1 + // - quad 1 holds cols 2, 3 + // - etc. + // + // This fits with the accumulator layout, since N strides in steps of 8 per thread. + + __syncthreads(); // LDSM.T done. We should now have a D x S tile in registers. SMEM can be + // written. + constexpr int ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b::VALUE; + static_assert(ROWS_PER_XOR_PATTERN == 8); + + int row, col; + if (WARPS_4x1x1) { + row = (tidx & 0x10) / 2 + (tidx & 0x07); + col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; + } else if (WARPS_4x1x2) { + // Same as above, with second warp group writing next 16 rows. + row = (tidx & 0x80) / 8 + (tidx & 0x10) / 2 + (tidx & 0x07); + col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; + } else { + assert(false); + } + col ^= (row & 0x07); + int dst = smem_tmp + row * BYTES_PER_ROW + col * BYTES_PER_LDS; + +#pragma unroll + for (int ni = 0; ni < N; ni++) { +#pragma unroll + for (int ki = 0; ki < K; ki++) { + fmha::stsm(dst + ki * BYTES_PER_ROW * D, regs[ni][ki]); + } + if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N. + dst += 16 * BYTES_PER_ROW; + } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N. + dst += 32 * BYTES_PER_ROW; + } else { + assert(false); + } + } + } + + Store_delegate delegate; + Transposer transposer; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v< + fmha::Hopper_qgmma_e4m3_fp32_traits, Cta_tile, + BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public Smem_tile_v_gmma< + fmha::Hopper_qgmma_e4m3_fp32_traits, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA>::Base { + using Traits = fmha::Hopper_qgmma_e4m3_fp32_traits; + + using Base = + typename fmha::Smem_tile_v_gmma::Base; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v< + fmha::Hopper_igmma_int8_int32_traits, Cta_tile, + BUFFERS_PER_TILE, desc_mode, USE_TMA> + : public Smem_tile_v_gmma< + fmha::Hopper_igmma_int8_int32_traits, + Cta_tile, BUFFERS_PER_TILE, desc_mode, USE_TMA>::Base { + using Traits = fmha::Hopper_igmma_int8_int32_traits; + + using Base = + typename fmha::Smem_tile_v_gmma::Base; + + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/smem_tile_o.h b/csrc/fmha_v2/fmha/hopper/smem_tile_o.h new file mode 100644 index 0000000000..cd499a5f39 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/smem_tile_o.h @@ -0,0 +1,325 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Smem_tile_o_dummy { + enum { BYTES_PER_TILE = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o_gmma_32bit_8bit : public Smem_tile_o_base_8bit_mma { + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + using Mma_tile = typename Base::Mma_tile; + using Accumulator = typename Base::Accumulator; + + enum { + BYTES_PER_ROW = Base::BYTES_PER_ROW, + BYTES_PER_ROW_WITH_PACKING = Base::BYTES_PER_ROW_WITH_PACKING, + LOOPS = Base::LOOPS, + LDS_PER_LOOP = Base::LDS_PER_LOOP, + ROWS_PER_LDS = Base::ROWS_PER_LDS, + HAS_INCOMPLETE_LDS = Base::HAS_INCOMPLETE_LDS, + }; + + // Ctor. + inline __device__ Smem_tile_o_gmma_32bit_8bit(void* smem, int tidx) : Base(smem, tidx) {} + + // Store the accumulators. + inline __device__ void store(Accumulator const (&acc)[1][1], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + + static_assert(M_PER_MMA == 64); + static_assert(Base::WARPS_4x1x2); + + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + static_assert(MMAS_M_PER_LOOP == 1); + static_assert(Mma_tile::MMAS_N == 1); + static_assert(Mma_tile::CORES_N == 8); + static_assert(Accumulator::NUM_REGS == Mma_tile::CORES_N / 2 * 8); + static_assert(BYTES_PER_ROW == 64 * 4); + static_assert(Cta_tile::WARPS_K == 2); + + static_assert(Mma_tile::CORES_N / 2 == 4); + +#pragma unroll + for (int ni = 0; ni < Mma_tile::CORES_N / 2; ++ni) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + uint4 row_0; + row_0.x = acc[0][0].reg(ni * 8 + 0); // Even + row_0.y = acc[0][0].reg(ni * 8 + 4); // Odd + row_0.z = acc[0][0].reg(ni * 8 + 1); // Even + row_0.w = acc[0][0].reg(ni * 8 + 5); // Odd + uint4 row_1; + row_1.x = acc[0][0].reg(ni * 8 + 2); // Even + row_1.y = acc[0][0].reg(ni * 8 + 6); // Odd + row_1.z = acc[0][0].reg(ni * 8 + 3); // Even + row_1.w = acc[0][0].reg(ni * 8 + 7); // Odd + + // Regs_to_rows::extract(acc[mi * MMAS_M_PER_LOOP + mj][ni], row_0, row_1); + + // Each thread of a quad writes 16B per STS -> 64B per store. Account for 2 -> 128B. + int imm_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K + (ni / 2) * 128; + int imm_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K + (ni / 2) * 128; + + // Store the elements. + fmha::sts(this->smem_write_ + imm_0, row_0); + fmha::sts(this->smem_write_ + imm_1, row_1); + } + // Each thread of a quad writes 16B per STS -> 64B per store. + if (Mma_tile::MMAS_N == 1) { + this->smem_write_ ^= 64; + } else { + assert(false && "Unsupported"); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o, Cta_tile> + : public Hmma_smem_tile_o< + Hopper_hgmma_fp16_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_hgmma_fp16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + using Mma_tile = typename Base::Mma_tile; + + using Accumulator = typename Base::Accumulator; + + enum { + LOOPS = Base::LOOPS, + ROW_PACKING = Base::ROW_PACKING, + BYTES_PER_ROW = Base::BYTES_PER_ROW, + }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store the accumulators. 
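+ // (For the fp16 traits the accumulator registers already hold packed half2 values, so each
+ //  core column needs just two 32-bit STS: one for the +0 row and one for the +8 row.)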
+ inline __device__ void store(Accumulator const (&acc)[1][1], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::CORES_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + static_assert(MMAS_M_PER_LOOP == 1); + // inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1, + auto smem_write = this->smem_write_ ^ (ni * 16); +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + // Store. + fmha::sts(smem_write + row_0, acc[0][0].reg(ni * 2 + 0)); + fmha::sts(smem_write + row_1, acc[0][0].reg(ni * 2 + 1)); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o, Cta_tile> + : public Hmma_smem_tile_o< + Hopper_hgmma_fp32_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_hgmma_fp32_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + using Mma_tile = typename Base::Mma_tile; + + using Accumulator = typename Base::Accumulator; + + enum { + LOOPS = Base::LOOPS, + ROW_PACKING = Base::ROW_PACKING, + BYTES_PER_ROW = Base::BYTES_PER_ROW, + }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store the accumulators. + inline __device__ void store(Accumulator const (&acc)[1][1], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::CORES_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + static_assert(MMAS_M_PER_LOOP == 1); + // inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1, + auto smem_write = this->smem_write_ ^ (ni * 16); +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + uint32_t val_0 = float2_to_half2(acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 0), + acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 1)); + + uint32_t val_1 = float2_to_half2(acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 2), + acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 3)); + + // Store. + fmha::sts(smem_write + row_0, val_0); + fmha::sts(smem_write + row_1, val_1); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o, Cta_tile> + : public Hmma_smem_tile_o< + Hopper_hgmma_bf16_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_hgmma_bf16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + using Mma_tile = typename Base::Mma_tile; + + using Accumulator = typename Base::Accumulator; + + enum { + LOOPS = Base::LOOPS, + ROW_PACKING = Base::ROW_PACKING, + BYTES_PER_ROW = Base::BYTES_PER_ROW, + }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Convert fp32 to bf16, and store the accumulators. 
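+ // (float2_to_bf16_x2 packs two adjacent fp32 accumulator elements into one 32-bit bf16x2
+ //  word, so a single STS per row offset writes both of them.)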
+ inline __device__ void store(Accumulator const (&acc)[1][1], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + + static_assert(Mma_tile::CORES_M == 2); + +#pragma unroll + for (int ni = 0; ni < Mma_tile::CORES_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + static_assert(MMAS_M_PER_LOOP == 1); + // inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1, + auto smem_write = this->smem_write_ ^ (ni * 16); +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + uint32_t val_0 = float2_to_bf16_x2(acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 0), + acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 1)); + + uint32_t val_1 = float2_to_bf16_x2(acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 2), + acc[0][0].elt(2 * ni * Mma_tile::CORES_M + 3)); + + // Store. + fmha::sts(smem_write + row_0, val_0); + fmha::sts(smem_write + row_1, val_1); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o, + Cta_tile> + : public Smem_tile_o_gmma_32bit_8bit< + Hopper_qgmma_e4m3_fp32_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_qgmma_e4m3_fp32_traits; + // The base class. + using Base = Smem_tile_o_gmma_32bit_8bit; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +template +struct Smem_tile_o, + Cta_tile> + : public Smem_tile_o_gmma_32bit_8bit< + Hopper_igmma_int8_int32_traits, Cta_tile> { + // The traits class. + using Traits = Hopper_igmma_int8_int32_traits; + // The base class. + using Base = Smem_tile_o_gmma_32bit_8bit; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/tma_descriptor.h b/csrc/fmha_v2/fmha/hopper/tma_descriptor.h new file mode 100644 index 0000000000..22071f3585 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/tma_descriptor.h @@ -0,0 +1,348 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include + +namespace fmha { + +// manage TMA descriptor host code. +// allocate, deallocate and manipulate tma desc in the host +// copy the tma descriptor from host code to device code +// Multiple TMA desc, one desc per batch. +// Device desc ptr should be allocated outside the class and reused +template < + // number of dimensions. 
+ int NUM_DIMS> +class Multiple_tma_descriptor { + public: + // ctor + Multiple_tma_descriptor(int batch_size_) : batch_size(batch_size_) { + if (batch_size > 0) { + // allocate host memory + desc_ptr_h = new cudaTmaDesc[batch_size]; + // make sure all bit fields are zeros. + memset(desc_ptr_h, 0, sizeof(cudaTmaDesc) * batch_size); + } + } + + // ctor + Multiple_tma_descriptor() = default; + + // destructor. + ~Multiple_tma_descriptor() { + if (batch_size > 0) { + // deallocate host memory + delete[] desc_ptr_h; + } + } + + // set the desctriptor. + int set_tma_desctriptor( + // ptr to gmem + void const* gmem_ptr, + // format is really data_type in TMA terminology. + cudaTmaDescFormat format, + // interleave mode. + cudaTmaDescInterleave interleave, + // swizzle mode. + cudaTmaDescSwizzle swizzle, + // L2 sector promotion. + cudaTmaDescPromotion promotion, uint32_t const (&tensor_size_array)[NUM_DIMS], + uint64_t const (&tensor_stride_array)[NUM_DIMS - 1], + uint32_t const (&traversal_stride_array)[NUM_DIMS], + uint32_t const (&box_size_array)[NUM_DIMS], + // OOB fill mode. + uint32_t fill_oob, + // FP32 to TF32 conversion. + uint32_t round_to_tf32, + // index to desc. + int batch_idx) { + set_tensor_common_0(&desc_ptr_h[batch_idx], reinterpret_cast(gmem_ptr)); + set_tensor_common_1(&desc_ptr_h[batch_idx], TILED, NUM_DIMS, format, interleave, swizzle, + fill_oob, round_to_tf32, promotion); + + set_tensor_stride(&desc_ptr_h[batch_idx], tensor_stride_array); + set_tensor_size(&desc_ptr_h[batch_idx], tensor_size_array); + + set_traversal_stride_tiled(&desc_ptr_h[batch_idx], traversal_stride_array); + + set_box_size(&desc_ptr_h[batch_idx], box_size_array); + return 0; + } + + // set the desctriptor. + int set_tma_desctriptor( + // ptr to gmem + void const* gmem_ptr, + // format is really data_type in TMA terminology. + cudaTmaDescFormat format, + // interleave mode. + cudaTmaDescInterleave interleave, + // swizzle mode. + cudaTmaDescSwizzle swizzle, + // L2 sector promotion. + cudaTmaDescPromotion promotion, uint32_t const (&tensor_size_array)[NUM_DIMS], + uint64_t const (&tensor_stride_array)[NUM_DIMS - 1], + uint32_t const (&traversal_stride_array)[NUM_DIMS], + uint32_t const (&box_size_array)[NUM_DIMS], + // OOB fill mode. + uint32_t fill_oob, + // FP32 to TF32 conversion. + uint32_t round_to_tf32, + // index to desc. 
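+ // (Unlike the overload above, this variant takes the destination descriptor directly
+ //  instead of a batch index into the host-side array.)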
+      cudaTmaDesc* desc_ptr = nullptr) {
+    set_tensor_common_0(desc_ptr, reinterpret_cast<uint64_t>(gmem_ptr));
+    set_tensor_common_1(desc_ptr, TILED, NUM_DIMS, format, interleave, swizzle, fill_oob,
+                        round_to_tf32, promotion);
+
+    set_tensor_stride(desc_ptr, tensor_stride_array);
+    set_tensor_size(desc_ptr, tensor_size_array);
+
+    set_traversal_stride_tiled(desc_ptr, traversal_stride_array);
+
+    set_box_size(desc_ptr, box_size_array);
+    return 0;
+  }
+
+  // copy the desc to device memory
+  void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) {
+    FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
+                               cudaMemcpyHostToDevice));
+  }
+
+  // get desc in host
+  cudaTmaDesc get_desc_in_host(int batch_idx) const { return desc_ptr_h[batch_idx]; }
+
+ private:
+  void set_tensor_common_0(cudaTmaDesc* p_desc, uint64_t addr) {
+    cudaTmaDescTiled* desc = reinterpret_cast<cudaTmaDescTiled*>(p_desc);
+    desc->tensor_common0 = 0;
+    desc->tensor_common0 |= (addr);
+  }
+
+  void set_tensor_common_1(cudaTmaDesc* p_desc, cudaTmaDescType desc_type, uint32_t dims,
+                           cudaTmaDescFormat format, cudaTmaDescInterleave interleave,
+                           cudaTmaDescSwizzle swizzle, uint32_t fill, uint32_t f32_to_tf32,
+                           cudaTmaDescPromotion promotion) {
+    cudaTmaDescTiled* desc = reinterpret_cast<cudaTmaDescTiled*>(p_desc);
+
+    desc->tensor_common1 = 0;
+    desc->tensor_common1 |= desc_type == TILED ? 0x0 : 0x1;
+
+    constexpr uint32_t VERSION_SHIFT = 1;
+    constexpr uint32_t VERSION_BITS = 3;
+    desc->tensor_common1 |= (1u << VERSION_SHIFT);
+
+    constexpr uint32_t DIM_BITS = 3;
+    constexpr uint32_t DIM_SHIFT = VERSION_SHIFT + VERSION_BITS;
+    constexpr uint32_t DIM_MASK = (1u << DIM_BITS) - 1;
+    desc->tensor_common1 |= ((dims - 1) & DIM_MASK) << DIM_SHIFT;
+
+    constexpr uint32_t FORMAT_BITS = 4;
+    constexpr uint32_t FORMAT_SHIFT = DIM_SHIFT + DIM_BITS;
+    constexpr uint32_t FORMAT_MASK = (1u << FORMAT_BITS) - 1;
+    desc->tensor_common1 |= (static_cast<uint32_t>(format) & FORMAT_MASK) << FORMAT_SHIFT;
+
+    constexpr uint32_t INTERLEAVE_BITS = 2;
+    constexpr uint32_t INTERLEAVE_SHIFT = FORMAT_SHIFT + FORMAT_BITS;
+    constexpr uint32_t INTERLEAVE_MASK = (1u << INTERLEAVE_BITS) - 1;
+    desc->tensor_common1 |= (static_cast<uint32_t>(interleave) & INTERLEAVE_MASK)
+                            << INTERLEAVE_SHIFT;
+
+    constexpr uint32_t SWIZZLE_BITS = 2;
+    constexpr uint32_t SWIZZLE_SHIFT = INTERLEAVE_SHIFT + INTERLEAVE_BITS;
+    constexpr uint32_t SWIZZLE_MASK = (1u << SWIZZLE_BITS) - 1;
+    desc->tensor_common1 |= (static_cast<uint32_t>(swizzle) & SWIZZLE_MASK) << SWIZZLE_SHIFT;
+
+    constexpr uint32_t FILL_BITS = 1;
+    constexpr uint32_t FILL_SHIFT = SWIZZLE_SHIFT + SWIZZLE_BITS;
+    constexpr uint32_t FILL_MASK = (1u << FILL_BITS) - 1;
+    desc->tensor_common1 |= (static_cast<uint32_t>(fill) & FILL_MASK) << FILL_SHIFT;
+
+    constexpr uint32_t F32_TO_TF32_BITS = 1;
+    constexpr uint32_t F32_TO_TF32_SHIFT = FILL_SHIFT + FILL_BITS;
+    constexpr uint32_t F32_TO_TF32_MASK = (1u << F32_TO_TF32_BITS) - 1;
+    desc->tensor_common1 |= (static_cast<uint32_t>(f32_to_tf32) & F32_TO_TF32_MASK)
+                            << F32_TO_TF32_SHIFT;
+
+    constexpr uint32_t PROMOTION_BITS = 2;
+    constexpr uint32_t PROMOTION_SHIFT = F32_TO_TF32_SHIFT + F32_TO_TF32_BITS;
+    constexpr uint32_t PROMOTION_MASK = (1u << PROMOTION_BITS) - 1;
+    desc->tensor_common1 |= (static_cast<uint32_t>(promotion) & PROMOTION_MASK) << PROMOTION_SHIFT;
+  }
+
+  // note that tensor stride has 1 less dim.
+  void set_tensor_stride(cudaTmaDesc* p_desc, uint64_t const (&tensor_stride_array)[NUM_DIMS - 1]) {
+    cudaTmaDescTiled* desc = reinterpret_cast<cudaTmaDescTiled*>(p_desc);
+
+    constexpr uint32_t TENSOR_STRIDE_UPPER_BITS = 4;
+    constexpr uint32_t TENSOR_STRIDE_UPPER_MASK = (1u << TENSOR_STRIDE_UPPER_BITS) - 1;
+
+    for (uint32_t i = 0; i < NUM_DIMS - 1; i++) {
+      desc->tensor_stride_lower[i] = 0u;
+      uint64_t tensor_stride_lower_64b = (tensor_stride_array[i] >> 4) & 0xFFFFFFFFlu;
+      desc->tensor_stride_lower[i] = static_cast<uint32_t>(tensor_stride_lower_64b);
+    }
+    desc->tensor_stride_upper = 0u;
+
+    for (uint32_t i = 0; i < NUM_DIMS - 1; i++) {
+      uint64_t tensor_stride_temp = tensor_stride_array[i];
+      tensor_stride_temp = tensor_stride_temp >> 4;
+      uint64_t tensor_stride_upper = tensor_stride_temp >> 32;
+      uint32_t tensor_stride_upper_32b = static_cast<uint32_t>(tensor_stride_upper);
+      desc->tensor_stride_upper |=
+          ((tensor_stride_upper_32b & TENSOR_STRIDE_UPPER_MASK) << (i * TENSOR_STRIDE_UPPER_BITS));
+    }
+  }
+
+  void set_tensor_size(cudaTmaDesc* p_desc, uint32_t const (&tensor_size_array)[NUM_DIMS]) {
+    cudaTmaDescTiled* desc = reinterpret_cast<cudaTmaDescTiled*>(p_desc);
+    for (uint32_t dim = 0; dim < NUM_DIMS; dim++) {
+      desc->tensor_size[dim] = tensor_size_array[dim] - 1;
+    }
+  }
+
+  void set_traversal_stride_tiled(cudaTmaDesc* p_desc,
+                                  uint32_t const (&traversal_stride_array)[NUM_DIMS]) {
+    cudaTmaDescTiled* desc = reinterpret_cast<cudaTmaDescTiled*>(p_desc);
+
+    desc->traversal_stride_box_0 = 0;
+
+    constexpr uint32_t TRAVERSAL_STRIDE_BITS = 3;
+    constexpr uint32_t TRAVERSAL_STRIDE_MASK = (1u << TRAVERSAL_STRIDE_BITS) - 1;
+
+    for (uint32_t dim = 0; dim < NUM_DIMS; dim++) {
+      uint32_t traversal_stride = traversal_stride_array[dim] - 1;
+      traversal_stride = (traversal_stride & TRAVERSAL_STRIDE_MASK)
+                         << (dim * TRAVERSAL_STRIDE_BITS);
+      desc->traversal_stride_box_0 |= traversal_stride;
+    }
+  }
+
+  void set_box_size(cudaTmaDesc* p_desc, uint32_t const (&box_size_array)[NUM_DIMS]) {
+    cudaTmaDescTiled* desc = reinterpret_cast<cudaTmaDescTiled*>(p_desc);
+
+    desc->box_size_end = 0;
+
+    constexpr uint32_t BOX_SIZE_BITS = 8;
+    constexpr uint32_t BOX_SIZE_MASK = (1 << BOX_SIZE_BITS) - 1;
+
+    if (NUM_DIMS > 1) {
+      uint32_t box_size_0 = box_size_array[0] - 1;
+      box_size_0 = box_size_0 & BOX_SIZE_MASK;
+      box_size_0 = box_size_0 << 24;
+      desc->traversal_stride_box_0 |= box_size_0;
+    }
+
+    for (uint32_t dim = 1; dim < NUM_DIMS; dim++) {
+      uint32_t box_size = box_size_array[dim] - 1;
+      box_size = box_size & BOX_SIZE_MASK;
+      box_size = box_size << ((dim - 1) * BOX_SIZE_BITS);
+      desc->box_size_end |= box_size;
+    }
+  }
+
+  void set_traversal_stride_im2col(cudaTmaDesc* p_desc, uint32_t* p_traversal_stride,
+                                   uint32_t dims) {
+    cudaTmaDescIm2Col* desc = reinterpret_cast<cudaTmaDescIm2Col*>(p_desc);
+
+    desc->traversal_stride_range_c = 0;
+
+    constexpr uint32_t TRAVERSAL_STRIDE_BITS = 3;
+    constexpr uint32_t TRAVERSAL_STRIDE_MASK = (1u << (TRAVERSAL_STRIDE_BITS + 1)) - 1;
+
+    for (uint32_t dim = 0; dim < dims; dim++) {
+      uint32_t traversal_stride = p_traversal_stride[dim] - 1;
+      traversal_stride = (traversal_stride & TRAVERSAL_STRIDE_MASK)
+                         << (dim * TRAVERSAL_STRIDE_BITS);
+      desc->traversal_stride_range_c |= traversal_stride;
+    }
+  }
+
+  void set_range_c(cudaTmaDesc* p_desc, uint32_t range_c) {
+    cudaTmaDescIm2Col* desc = reinterpret_cast<cudaTmaDescIm2Col*>(p_desc);
+
+    constexpr uint32_t RANGE_C_BITS = 8;
+    constexpr uint32_t RANGE_C_MASK = (1u << RANGE_C_BITS) - 1;
+
+    range_c = range_c & RANGE_C_MASK;
+    desc->traversal_stride_range_c |= ((range_c - 1) << 24);
+  }
+
+  void set_box_corner_dhw(cudaTmaDesc* p_desc, uint32_t*
p_base_corner, uint32_t* p_far_corner, + uint32_t dims) { + cudaTmaDescIm2Col* desc = reinterpret_cast(p_desc); + + desc->box_corner_dhw = 0; + + uint32_t box_base_corner = 0, box_far_corner = 0; + uint32_t box_corner_dhw = 0; + + if (dims == 3) { + constexpr uint32_t BOX_CORNER_BITS = 16; + constexpr uint32_t BOX_CORNER_MASK = (1u << BOX_CORNER_BITS) - 1; + + box_base_corner = p_base_corner[0] & BOX_CORNER_MASK; + box_far_corner = p_far_corner[0] & BOX_CORNER_MASK; + } + + if (dims == 4) { + constexpr uint32_t BOX_CORNER_BITS = 8; + constexpr uint32_t BOX_CORNER_MASK = (1u << BOX_CORNER_BITS) - 1; + + box_base_corner = p_base_corner[0] & BOX_CORNER_MASK; + box_base_corner |= ((p_base_corner[1] & BOX_CORNER_MASK) << BOX_CORNER_BITS); + + box_far_corner = p_far_corner[0] & BOX_CORNER_MASK; + box_far_corner |= ((p_far_corner[1] & BOX_CORNER_MASK) << BOX_CORNER_BITS); + } + + if (dims == 5) { + constexpr uint32_t BOX_CORNER_BITS = 5; + constexpr uint32_t BOX_CORNER_MASK = (1u << BOX_CORNER_BITS) - 1; + + box_base_corner = p_base_corner[0] & BOX_CORNER_MASK; + box_base_corner |= ((p_base_corner[1] & BOX_CORNER_MASK) << BOX_CORNER_BITS); + box_base_corner |= ((p_base_corner[2] & BOX_CORNER_MASK) << (2 * BOX_CORNER_BITS)); + + box_far_corner = p_far_corner[0] & BOX_CORNER_MASK; + box_far_corner |= ((p_far_corner[1] & BOX_CORNER_MASK) << BOX_CORNER_BITS); + box_far_corner |= ((p_far_corner[2] & BOX_CORNER_MASK) << (2 * BOX_CORNER_BITS)); + } + + box_corner_dhw = box_base_corner; + box_corner_dhw |= (box_far_corner << 16); + + desc->box_corner_dhw = box_corner_dhw; + } + + void set_range_ndhw(cudaTmaDesc* p_desc, uint32_t ndhw) { + cudaTmaDescIm2Col* desc = reinterpret_cast(p_desc); + + desc->range_ndhw = 0; + + constexpr uint32_t RANGE_NDHW_BITS = 10; + constexpr uint32_t RANGE_NDHW_MASK = (1u << RANGE_NDHW_BITS) - 1; + + desc->range_ndhw = ((ndhw - 1) & RANGE_NDHW_MASK); + } + + // The TMA descriptor. Each is of 512 bit. + cudaTmaDesc* desc_ptr_h; + // The TMA descriptor on the device memory. + cudaTmaDesc* desc_ptr_d; + // Number of batches + int batch_size = 0; +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/tma_types.h b/csrc/fmha_v2/fmha/hopper/tma_types.h new file mode 100644 index 0000000000..4f5460ef64 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/tma_types.h @@ -0,0 +1,123 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include + +namespace fmha { + +// TMA desc type. +typedef enum { TILED = 0, IM2COL } cudaTmaDescType; + +// TMA swizzle type. +typedef enum { + SWIZZLE_DISABLED, + SWIZZLE_32B, + SWIZZLE_64B, + SWIZZLE_128B, + SWIZZLE_MAX +} cudaTmaDescSwizzle; + +typedef enum { BARRIER64, BARRIER128 } cudaTmaDescBarrier; + +// TMA interleave type. +typedef enum { + INTERLEAVE_DISABLED, + INTERLEAVE_16B, + INTERLEAVE_32B, + INTERLEAVE_MAX +} cudaTmaDescInterleave; + +// TMA L2 sector promotion. 
+typedef enum { + PROMOTION_DISABLED = 0, + PROMOTION_64B, + PROMOTION_128B, + PROMOTION_256B +} cudaTmaDescPromotion; + +// TMA data type. +typedef enum { + U8 = 0, + U16, + U32, + S32, + U64, + S64, + F16_RN, + F32_RN, + F32_FTZ_RN, + F64_RN, + BF16_RN, + FORMAT_MAX +} cudaTmaDescFormat; + +// TMA cache control. +typedef enum { + PREFETCH, // Prefetch tma descriptor using global memory address + INVALIDATE, // Invalidate tma descriptor in l2 cache + INVALIDATE_ALL // Invalidate tma descriptor and all elements in l2 cache line +} cudaTmaDescCacheCtrl; + +// TMA OOB fill modes. +typedef enum { TENSOR_ZFILL, TENSOR_CFILL } cudaTmaDescOobFillMode; + +constexpr uint64_t k_max_tensor_size = (1llu << 36); +constexpr uint64_t k_max_tensor_stride = (1llu << 36); +constexpr uint64_t k_max_block_size = 256llu; +constexpr uint64_t k_max_traversal_stride = (1llu << 3); + +constexpr uint64_t k_min_tensor_size = 1llu; +constexpr uint64_t k_min_tensor_stride = 0llu; +constexpr uint64_t k_min_block_size = 1llu; +constexpr uint64_t k_min_traversal_stride = 1llu; + +constexpr uint32_t k_max_cta_id = (1 << 6) - 1; + +// The 512 bit of descriptor for tiled mode. +typedef struct { + uint64_t tensor_common0; + uint32_t tensor_common1; + + uint32_t tensor_stride_lower[4]; //< 36b of 64b with 4B aligned + uint32_t tensor_stride_upper; + uint32_t tensor_size[5]; //< value -1 + uint32_t traversal_stride_box_0; //< packed 3b (-1) + + uint32_t box_size_end; +} cudaTmaDescTiled; + +// The 512 bit of descritptro for im2col mode. +typedef struct { + uint64_t tensor_common0; + uint32_t tensor_common1; + + uint32_t tensor_stride_lower[4]; + uint32_t tensor_stride_upper; + uint32_t tensor_size[5]; + uint32_t traversal_stride_range_c; + + uint32_t box_corner_dhw; + uint32_t range_ndhw; +} cudaTmaDescIm2Col; + +// TMA desc size +constexpr uint32_t TMA_DESC_SIZE_IN_BYTE = 64; + +// TMA desc +typedef struct alignas(64) { + uint64_t data[8]; +} cudaTmaDesc; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_gmma.h b/csrc/fmha_v2/fmha/hopper/utils_gmma.h new file mode 100644 index 0000000000..cc070be7de --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_gmma.h @@ -0,0 +1,18 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include diff --git a/csrc/fmha_v2/fmha/hopper/utils_hgmma.h b/csrc/fmha_v2/fmha/hopper/utils_hgmma.h new file mode 100644 index 0000000000..5112317228 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_hgmma.h @@ -0,0 +1,874 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. 
Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// GMMAs with fp16 Accumulator +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<8, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[2]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16\n" + "{\n" + " %0, %1\n" + "}, %2, %3, 1, 1, 1, %4, %5;\n" + + : "+r"(acc[0]), "+r"(acc[1]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<32, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[8]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7 \n" + "},\n" + " %8, %9, 1, 1, 1, %10, %11;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<64, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + " %16, %17, 1, 1, 1, %18, %19;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<128, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1, %34, %35;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<192, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[48]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47 \n" + "},\n" + " %48, %49, 1, 1, 1, %50, %51;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp16<256, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1, %66, %67;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_fp16(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 4]) { + Hgmma_fp16::mma(desc_a, desc_b, acc); +} + 
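A minimal usage sketch for the fp16-accumulator wrappers above, assuming the template parameter list stripped from hgmma_fp16 is <int N, bool TA, bool TB> (as the specializations suggest) and that desc_a/desc_b are GMMA shared-memory descriptors built elsewhere. The accumulator size follows from the wgmma register layout: a 64xN fp16 tile spread over the 128 threads of a warpgroup leaves N/2 half values per thread, packed two per 32-bit register, hence uint32_t (&acc)[N / 4]; the fp32-accumulator wrappers below need a full register per element, hence [N / 2].

// Sketch only, assuming <int N, bool TA, bool TB>; the wgmma fence / commit_group /
// wait_group protocol and the construction of desc_a / desc_b are intentionally omitted.
template <int N>
inline __device__ void hgmma_fp16_example(uint64_t desc_a, uint64_t desc_b,
                                          uint32_t (&acc)[N / 4]) {
  // 64 x N fp16 accumulator over 128 threads: N/2 halves per thread = N/4 packed registers.
  fmha::hgmma_fp16<N, false, false>(desc_a, desc_b, acc);  // neither A nor B transposed
}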
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// GMMAs with fp32 Accumulator +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<8, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3}, %4, %5, 1, 1, 1, %6, %7;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<64, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1, %34, %35;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<128, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1, %66, %67;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<192, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1, %98, %99;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_fp32<256, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1, %130, %131;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_fp32(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 2]) { + Hgmma_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// GMMAs with fp16 Accumulator, where A is coming from RF +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16 {}; + 
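The RF-A variants that follow take the A operand from the register file instead of a shared-memory descriptor: for one 64x16 fp16 A tile, each of the 128 threads holds 8 half values in 4 packed registers, which is the uint32_t (&a)[4] argument, and only B is still addressed through a descriptor, so only a TB transpose flag remains. A minimal sketch, assuming the stripped primary-template parameters are <int N, bool TB>:

// Sketch only, assuming <int N, bool TB>: a_frag must already hold this thread's share of
// the 64x16 fp16 A tile (8 halves in 4 registers), desc_b is a GMMA shared-memory
// descriptor, and the surrounding wgmma fence / commit / wait protocol is omitted.
template <int N>
inline __device__ void hgmma_rfa_fp16_example(uint32_t const (&a_frag)[4], uint64_t desc_b,
                                              uint32_t (&acc)[N / 4]) {
  fmha::Hgmma_rfa_fp16<N, false>::mma(a_frag, desc_b, acc);  // B kept untransposed
}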
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// 64x8x16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool TB>
+struct Hgmma_rfa_fp16<8, TB> {
+  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[2]) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
+    int const trans_b = TB ? 1 : 0;
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
+        "{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"
+
+        : "+r"(acc[0]), "+r"(acc[1])
+        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b));
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// 64x16x16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool TB>
+struct Hgmma_rfa_fp16<16, TB> {
+  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
+    int const trans_b = TB ? 1 : 0;
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 "
+        "{ %0, %1, %2, %3 },\n"
+        "{ %4, %5, %6, %7 }, %8, 1, 1, 1, %9;\n"
+
+        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3])
+        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b));
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// 64x32x16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool TB>
+struct Hgmma_rfa_fp16<32, TB> {
+  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[8]) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
+    int const trans_b = TB ? 1 : 0;
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 "
+        "{ %0, %1, %2, %3, %4, %5, %6, %7 },\n"
+        "{ %8, %9, %10, %11 }, %12, 1, 1, 1, %13;\n"
+
+        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
+          "+r"(acc[6]), "+r"(acc[7])
+        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b));
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// 64x64x16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool TB>
+struct Hgmma_rfa_fp16<64, TB> {
+  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
+    int const trans_b = TB ?
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1, %21;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<128, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1, %37;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<192, TB> { + static inline __device__ void mma(const uint32_t (&a)[4], uint64_t desc_b, uint32_t (&acc)[48]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47 \n" + "},\n" + "{ %48, %49, %50, %51 }, %52, 1, 1, 1, %53;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp16<256, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1, %69;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_rfa_fp16(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 4]) { + Hgmma_rfa_fp16::mma(a, desc_b, 
acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// GMMAs with fp32 Accumulator, where A is coming from RF +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<8, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3\n" + "}\n," + "{ %4, %5, %6, %7 }, %8, 1, 1, 1, %9;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<32, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1, %21;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<64, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1, %37;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<128, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1, %69;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<192, TB> { + static inline __device__ void mma(const uint32_t (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1, %101;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_fp32<256, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1, %133;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_rfa_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Hgmma_rfa_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h b/csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h new file mode 100644 index 0000000000..7b17b508bb --- 
/dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h @@ -0,0 +1,475 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// BF16 GMMAs with FP32 Accumulator +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<8, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3}, %4, %5, 1, 1, 1, %6, %7;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<64, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1, %34, %35;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<128, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1, %66, %67;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<192, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1, %98, %99;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_bf16<256, TA, TB> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_a = TA ? 1 : 0; + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1, %130, %131;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b), "n"(trans_a), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_bf16(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 2]) { + Hgmma_bf16::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// BF16 GMMAs with FP32 Accumulator, where A is coming from RF +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16 {}; + 
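+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Usage sketch (illustrative only).  It assumes the dispatcher above carries the full template
+// header `template <int N, bool TA, bool TB>` (as the <N, TA, TB> specializations imply), that
+// desc_a/desc_b are valid SMEM matrix descriptors, and that the caller issues the required
+// wgmma fence / commit_group / wait_group sequencing around the call:
+//
+//   uint32_t acc[256 / 2] = {0};   // each thread of the warpgroup owns N/2 fp32 accumulators
+//   fmha::hgmma_bf16<256, /*TA=*/false, /*TB=*/false>(desc_a, desc_b, acc);
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+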
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<8, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3\n" + "}\n," + "{ %4, %5, %6, %7 }, %8, 1, 1, 1, %9;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<32, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1, %21;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<64, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1, %37;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<128, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1, %69;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<192, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1, %101;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hgmma_rfa_bf16<256, TB> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + int const trans_b = TB ? 
1 : 0; + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1, %133;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void hgmma_rfa_bf16(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Hgmma_rfa_bf16::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_igmma.h b/csrc/fmha_v2/fmha/hopper/utils_igmma.h new file mode 100644 index 0000000000..fcced80616 --- /dev/null 
+++ b/csrc/fmha_v2/fmha/hopper/utils_igmma.h @@ -0,0 +1,396 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// IGMMA 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Igmma_int8_int32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_int8_int32<64> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_int8_int32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), 
"+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_int8_int32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> 
+struct Igmma_int8_int32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void igmma_int8_int32(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 2]) { + Igmma_int8_int32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// IGMMA 64xNx32 TN with int32 Accumulator with 
A from RF and B from SMEM. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Igmma_rfa_int8_int32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_rfa_int8_int32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_rfa_int8_int32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + 
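+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Operand sketch (illustrative only): the RF-A variants take the A fragment as four 32-bit
+// registers instead of an SMEM descriptor, so the PTX operand list is
+//
+//   { d0, ..., d(N/2 - 1) }, { a0, a1, a2, a3 }, desc_b, 1;
+//
+// whereas the SMEM-A variants above use `desc_a, desc_b, 1;`.  Assuming the dispatcher below is
+// declared `template <int N>`, a 64x128x32 call would look like (a_frag being the thread's four
+// A-fragment registers):
+//
+//   uint32_t acc[128 / 2] = {0};   // s32 accumulators, N/2 per warpgroup thread
+//   fmha::igmma_rfa_int8_int32<128>(a_frag, desc_b, acc);
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+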
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_rfa_int8_int32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Igmma_rfa_int8_int32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, 
%57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void igmma_rfa_int8_int32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Igmma_rfa_int8_int32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_qgmma.h b/csrc/fmha_v2/fmha/hopper/utils_qgmma.h new file mode 100644 index 0000000000..28571b15b9 --- /dev/null +++ b/csrc/fmha_v2/fmha/hopper/utils_qgmma.h @@ -0,0 +1,2089 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_e4m3_e4m3_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<32> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15\n" + "},\n" + " %16, %17, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<64> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), 
"+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e4m3_fp32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + 
"+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_e4m3_e4m3_fp32(uint64_t desc_a, uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_e4m3_e4m3_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_rfa_e4m3_e4m3_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<32> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15\n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), 
"+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e4m3_fp32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), 
"+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_rfa_e4m3_e4m3_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_rfa_e4m3_e4m3_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e4m3 x e5m2 - 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_e4m3_e5m2_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<8> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + " %8, %9, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<32> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + " %16, %17, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<64> { 
+ static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<160> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " 
%8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + " %80, %81, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), 
"+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e4m3_e5m2_fp32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), 
"+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_e4m3_e5m2_fp32(uint64_t desc_a, uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_e4m3_e5m2_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e4m3 x e5m2 - 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_rfa_e4m3_e5m2_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<8> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + "{ %8, %9, %10, %11 }, %12, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<32> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && 
defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<160> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, 
%14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + "{ %80, %81, %82, %83 }, %84, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), 
"+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e4m3_e5m2_fp32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + 
"+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_rfa_e4m3_e5m2_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_rfa_e4m3_e5m2_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e5m2 x e4m3 - 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_e5m2_e4m3_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<8> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + " %8, %9, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<32> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + " %16, %17, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<64> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && 
defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<160> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, 
%34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + " %80, %81, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), 
"+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e4m3_fp32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + 
"+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_e5m2_e4m3_fp32(uint64_t desc_a, uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_e5m2_e4m3_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e5m2 x e4m3 - 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_rfa_e5m2_e4m3_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<8> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + "{ %8, %9, %10, %11 }, %12, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<32> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, 
%12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<160> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, 
%44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + "{ %80, %81, %82, %83 }, %84, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + 
"+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e4m3_fp32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), 
"+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_rfa_e5m2_e4m3_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_rfa_e5m2_e4m3_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e5m2 x e5m2 - 64xNx32 TN with int32 Accumulator with A and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qgmma_e5m2_e5m2_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<8> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + " %8, %9, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<64> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + " %32, %33, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct 
Qgmma_e5m2_e5m2_fp32<128> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + " %64, %65, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<160> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + " %80, %81, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), 
"+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<192> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + " %96, %97, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_e5m2_e5m2_fp32<256> { + static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && 
defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + " %128, %129, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "l"(desc_a), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_e5m2_e5m2_fp32(uint64_t desc_a, uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_e5m2_e5m2_fp32::mma(desc_a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// QGMMA e5m2 x e5m2 - 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct Qgmma_rfa_e5m2_e5m2_fp32 {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x8x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<8> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[4]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7\n" + "},\n" + "{ %8, %9, %10, %11 }, %12, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x32x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<32> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[16]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15 \n" + "},\n" + "{ %16, %17, %18, %19 }, %20, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x64x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<64> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31 \n" + "},\n" + "{ %32, %33, %34, %35 }, %36, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x128x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<128> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t 
desc_b, uint32_t (&acc)[64]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63 \n" + "},\n" + "{ %64, %65, %66, %67 }, %68, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x160x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<160> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[80]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79 \n" + "},\n" + "{ %80, %81, %82, %83 }, %84, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), 
"+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x192x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<192> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95 \n" + "},\n" + "{ %96, %97, %98, %99 }, %100, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// 64x256x32 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Qgmma_rfa_e5m2_e5m2_fp32<256> { + static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t 
(&acc)[128]) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2\n" + "{\n" + " %0, %1, %2, %3, %4, %5, %6, %7,\n" + " %8, %9, %10, %11, %12, %13, %14, %15,\n" + " %16, %17, %18, %19, %20, %21, %22, %23,\n" + " %24, %25, %26, %27, %28, %29, %30, %31,\n" + " %32, %33, %34, %35, %36, %37, %38, %39,\n" + " %40, %41, %42, %43, %44, %45, %46, %47,\n" + " %48, %49, %50, %51, %52, %53, %54, %55,\n" + " %56, %57, %58, %59, %60, %61, %62, %63,\n" + " %64, %65, %66, %67, %68, %69, %70, %71,\n" + " %72, %73, %74, %75, %76, %77, %78, %79,\n" + " %80, %81, %82, %83, %84, %85, %86, %87,\n" + " %88, %89, %90, %91, %92, %93, %94, %95,\n" + " %96, %97, %98, %99, %100, %101, %102, %103,\n" + " %104, %105, %106, %107, %108, %109, %110, %111,\n" + " %112, %113, %114, %115, %116, %117, %118, %119,\n" + " %120, %121, %122, %123, %124, %125, %126, %127 \n" + "},\n" + "{ %128, %129, %130, %131 }, %132, 1, 1, 1;\n" + + : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]), + "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]), + "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]), + "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]), + "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]), + "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]), + "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]), + "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]), + "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]), + "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]), + "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]), + "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]), + "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]), + "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]), + "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]), + "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]), + "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]), + "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]), + "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]), + "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]), + "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]), + "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]), + "+r"(acc[126]), "+r"(acc[127]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b)); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void qgmma_rfa_e5m2_e5m2_fp32(uint32_t const (&a)[4], uint64_t desc_b, + uint32_t (&acc)[N / 2]) { + Qgmma_rfa_e5m2_e5m2_fp32::mma(a, desc_b, acc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/hopper/utils_tma.h 
b/csrc/fmha_v2/fmha/hopper/utils_tma.h
new file mode 100644
index 0000000000..faa63edb81
--- /dev/null
+++ b/csrc/fmha_v2/fmha/hopper/utils_tma.h
@@ -0,0 +1,155 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace fmha {
+
+inline __device__ uint32_t elect_one_sync();
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int DIM, fmha::cudaTmaDescType DESC_TYPE, bool USE_TMA_MULTICAST>
+inline __device__ void utmaldg(cudaTmaDesc const* p_desc,    // TMA desc
+                               uint32_t smem_ptr,            // desc smem address
+                               uint32_t smem_barrier,        // smem_barrier
+                               int32_t const (&coord)[DIM],  // coord
+                               uint32_t elect_one = 1);
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// UTMALDG TILED WITHOUT MULTICAST
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+inline __device__ void utmaldg<2, fmha::cudaTmaDescType::TILED, false>(cudaTmaDesc const* p_desc,
+                                                                       uint32_t smem_ptr,
+                                                                       uint32_t smem_barrier,
+                                                                       int32_t const (&coord)[2],
+                                                                       uint32_t elect_one) {
+  if (elect_one) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+    asm volatile(
+        "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
+        "[%0], [%1, {%2, %3}], [%4];\n"
+        :
+        : "r"(smem_ptr), "l"(reinterpret_cast<uint64_t>(p_desc)), "r"(coord[0]), "r"(coord[1]),
+          "r"(smem_barrier)
+        : "memory");
+#endif
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+inline __device__ void utmaldg<3, fmha::cudaTmaDescType::TILED, false>(cudaTmaDesc const* p_desc,
+                                                                       uint32_t smem_ptr,
+                                                                       uint32_t smem_barrier,
+                                                                       int32_t const (&coord)[3],
+                                                                       uint32_t elect_one) {
+  if (elect_one) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+    asm volatile(
+        "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes "
+        "[%0], [%1, {%2, %3, %4}], [%5];\n"
+        :
+        : "r"(smem_ptr), "l"(reinterpret_cast<uint64_t>(p_desc)), "r"(coord[0]), "r"(coord[1]),
+          "r"(coord[2]), "r"(smem_barrier)
+        : "memory");
+#endif
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// 4D, TILED, without Multicast
+template <>
+inline __device__ void utmaldg<4, fmha::cudaTmaDescType::TILED, false>(cudaTmaDesc const* p_desc,
+                                                                       uint32_t smem_ptr,
+                                                                       uint32_t smem_barrier,
+                                                                       int32_t const (&coord)[4],
+                                                                       uint32_t elect_one) {
+  if (elect_one) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+    asm volatile(
+        "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes "
+        "[%0], [%1, {%2, %3, %4, %5}], [%6];\n"
+        :
+        : "r"(smem_ptr), "l"(reinterpret_cast<uint64_t>(p_desc)), "r"(coord[0]), "r"(coord[1]),
+          "r"(coord[2]), "r"(coord[3]), "r"(smem_barrier)
+        : "memory");
+#endif
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// UTMASTG TILED WITHOUT MULTICAST
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int DIM, fmha::cudaTmaDescType DESC_TYPE>
+inline __device__ void utmastg(cudaTmaDesc const* p_desc,      // TMA desc
+                               uint32_t smem_ptr,              // src smem address
+                               int32_t const (&coord)[DIM]);   // coord
+
+// 3D, TILED
+template <>
+inline __device__ void utmastg<3, fmha::cudaTmaDescType::TILED>(cudaTmaDesc const* p_desc,
+                                                                uint32_t smem_ptr,
+                                                                const int32_t (&coord)[3]) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  asm volatile(
+      "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%1, %2, %3}], [%4];\n" ::"l"(
+          reinterpret_cast<uint64_t>(p_desc)),
+      "r"(coord[0]), "r"(coord[1]), "r"(coord[2]), "r"(smem_ptr)
+      : "memory");
+#endif
+}
+
+// 4D, TILED
+template <>
+inline __device__ void utmastg<4, fmha::cudaTmaDescType::TILED>(cudaTmaDesc const* p_desc,
+                                                                uint32_t smem_ptr,
+                                                                int32_t const (&coord)[4]) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  asm volatile(
+      "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%1, %2, %3, %4}], [%5];\n" ::"l"(
+          reinterpret_cast<uint64_t>(p_desc)),
+      "r"(coord[0]), "r"(coord[1]), "r"(coord[2]), "r"(coord[3]), "r"(smem_ptr)
+      : "memory");
+#endif
+}
+
+inline __device__ void tmastg_arrive() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  asm volatile("cp.async.bulk.commit_group;");
+#else
+  assert(false);
+#endif
+}
+
+inline __device__ void tmastg_wait() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory");
+#else
+  assert(false);
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+} // namespace fmha
diff --git a/csrc/fmha_v2/fmha/hopper/utils_warpgroup.h b/csrc/fmha_v2/fmha/hopper/utils_warpgroup.h
new file mode 100644
index 0000000000..8923316f61
--- /dev/null
+++ b/csrc/fmha_v2/fmha/hopper/utils_warpgroup.h
@@ -0,0 +1,44 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+
+namespace fmha {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void warpgroup_arrive() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
+  asm volatile("wgmma.fence.sync.aligned;\n" ::);
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void warpgroup_commit() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
+  asm volatile("wgmma.commit_group.sync.aligned;\n" ::);
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int N>
+inline __device__ void warpgroup_wait() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
+  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace fmha
diff --git a/csrc/fmha_v2/fmha/kernel_traits.h b/csrc/fmha_v2/fmha/kernel_traits.h
new file mode 100644
index 0000000000..8e1d5cbb22
--- /dev/null
+++ b/csrc/fmha_v2/fmha/kernel_traits.h
@@ -0,0 +1,879 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Ada hmma/imma reuses Ampere +template +struct Traits_reuse { + using Traits = Traits_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Traits_reuse { + using Traits = fmha::Ampere_hmma_fp16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Traits_reuse { + using Traits = fmha::Ampere_hmma_fp32_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Traits_reuse { + using Traits = fmha::Ampere_imma_int8_int32_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Traits_o_adapter { + using Traits = Traits_p; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Traits_o_adapter { + using Traits = fmha::Volta_hmma_fp16_16x16x16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// convert to fp16 before smem_o store +template <> +struct Traits_o_adapter { + using Traits = fmha::Ampere_hmma_fp16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// convert to fp16 before smem_o store +template <> +struct Traits_o_adapter { + using Traits = fmha::Turing_hmma_fp16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// convert to bf16 before smem_o store +template <> +struct Traits_o_adapter { + using Traits = fmha::Ampere_hmma_bf16_bf16_traits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // Instruction traits. + typename Traits_, + // The global memory tile for Q, K and V. + template class Gmem_tile_q_, + template class Gmem_tile_k_, + template class Gmem_tile_v_, + // The global memory tile for the output. + template class Gmem_tile_o_, + // Sequence length. + int S, + // The valid hidden dimension. + int VALID_D_, + // The valid hidden dimension of V. + int VALID_DV_, + // The iteration step of the outer loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD_, + // The flags to control the behaviour of LDGs. + uint32_t FLAGS, + // The version of the kernel. + int VERSION_, + // The mask version of the kernel + int MASK_VERSION_, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // non-positive means disabled + int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0> +struct Kernel_traits_ { + // The instruction traits for the Q*K product. + using Traits_p = typename Traits_reuse::Traits; + // The instruction traits for the P*V product. Hack to change the traits for Volta HMMA. + using Traits_o = typename Traits_o_adapter::Traits; + // The instruction traits for the epilogue of the 2nd GEMM. 
Always use FP16. + using Traits_e = typename Traits_o_adapter::Traits; + + // The padded D dimension + enum { VALID_D = VALID_D_ }; + + enum { D = Next_power_of_two::VALUE }; + + enum { VALID_DV = VALID_DV_ > 0 ? VALID_DV_ : VALID_D }; + + enum { DV = Next_power_of_two::VALUE }; + + enum { + SAGE_ATTENTION = SAGE_BLOCK_SIZE_Q_ > 0 || SAGE_BLOCK_SIZE_K_ > 0 || SAGE_BLOCK_SIZE_V_ > 0 + }; + + enum { SAGE_BLOCK_SIZE_Q = SAGE_BLOCK_SIZE_Q_ }; + + enum { SAGE_BLOCK_SIZE_K = SAGE_BLOCK_SIZE_K_ }; + + enum { SAGE_BLOCK_SIZE_V = SAGE_BLOCK_SIZE_V_ }; + + // TODO: expose these tiling params to the interface + enum { USE_GRANULAR_TILING = (FLAGS & 0x1000) != 0u }; // TODO ANT: check FLAGS + + using Traits_tile_size = + Traits_tile_size<(bool)USE_GRANULAR_TILING, STEP, S, D, DV, Traits_o::K_PER_MMA>; + + enum { CTA_P_TILE_M = Traits_tile_size::CTA_P_TILE_M }; + + enum { CTA_P_TILE_N = Traits_tile_size::CTA_P_TILE_N }; + + enum { CTA_P_TILE_K = Traits_tile_size::CTA_P_TILE_K }; + + enum { CTA_O_TILE_M = Traits_tile_size::CTA_O_TILE_M }; + + enum { CTA_O_TILE_N = Traits_tile_size::CTA_O_TILE_N }; + + enum { CTA_O_TILE_K = Traits_tile_size::CTA_O_TILE_K }; + + // Do we need to reload Q due to splitting the D ? + enum { RELOAD_Q = static_cast(CTA_P_TILE_K) != static_cast(D) }; + + // The CTA description for the 1st GEMM. + using Cta_tile_p = + typename Traits_p::template Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = + typename Traits_o::template Cta_tile_extd; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = typename Traits_p::template Mma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = typename Traits_o::template Mma_tile; + + // Compute the total BMM2_MMAS_K (might not the same as Mma_tile_o::MMAS_K if the granular tiling + // is used). + static_assert(S % CTA_O_TILE_K == 0, ""); + + enum { TOTAL_BMM2_MMAS_K = Mma_tile_o::MMAS_K * (S / CTA_O_TILE_K) }; + + // Constraints on the K dimension. + static_assert(Mma_tile_p::K_PER_MMA <= static_cast(D)); + static_assert(Mma_tile_o::K_PER_MMA <= S); + + // The version. + enum { VERSION = VERSION_ }; + + // The mask version: padding (2), causal (3), sliding_window_causal (4), custom_mask (5). + enum { MASK_VERSION = MASK_VERSION_ }; + + // Whether use causal mask or not. + enum { CAUSAL_MASK = MASK_VERSION_ == 3 || MASK_VERSION_ == 4 }; + + // Whether use the sliding window attention or not. + enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4 }; + + // Whether use the custom mask or not. + enum { CUSTOM_MASK = MASK_VERSION_ == 5 }; + + // Do we use LDGSTS for Q, K or V. + enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u }; + + enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u }; + + enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u }; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u }; + + // Do we use the scale max trick. + enum { USE_SCALE_MAX = (FLAGS & 0x10u) != 0u }; + + // Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d. + enum { HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u }; + + // Keep full K matrix in registers. 
+ enum { K_IN_REGS = (FLAGS & 0x40) == 0u }; + + // Do we use only 2 fragments or full fragments for frag_q/k (only used by flash attention) + enum { LIMIT_QK_FRAGMENTS = ((FLAGS & 0x80u) != 0u && !SHARE_SMEM_FOR_K_AND_V) }; + + // Do we use only 2 fragments or full fragments for frag_v (only used by flash attention) + enum { LIMIT_V_FRAGMENTS = ((FLAGS & 0x100u) != 0u && !SHARE_SMEM_FOR_K_AND_V) }; + + // Limiting QK fragments implies SMEM_K has to reside in SMEM + static_assert(!(LIMIT_QK_FRAGMENTS && SHARE_SMEM_FOR_K_AND_V), ""); + + // Indicates that kernel does not loop over Q tensor, usually kernel name has _nl suffix + enum { NO_LOOP = (FLAGS & 0x200u) != 0u }; + + // Are sequences in one batch interleaved. i.e. s x b x ..., or b x s x ... + enum { SEQUENCES_INTERLEAVED = (FLAGS & 0x400) != 0u }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = (FLAGS & 0x800) != 0u }; + + // Use MTP (multi-token prediction for MLA kernels) or not. + enum { IS_MTP = (FLAGS & 0x2000) != 0u }; + + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ }; + + // The number of shared memory buffers to build a software pipeline for Q, K and V. + enum { + BUFFERS_PER_TILE_SMEM_Q = (USE_GRANULAR_TILING && D > 64) || (USE_LDGSTS_Q && !NO_LOOP) ? 2 : 1 + }; + + enum { BUFFERS_PER_TILE_SMEM_K = USE_GRANULAR_TILING ? 2 : 1 }; + + enum { BUFFERS_PER_TILE_SMEM_V = USE_GRANULAR_TILING ? 2 : 1 }; + + // The global memory tile to load Q. + using Gmem_tile_q = Gmem_tile_q_; + + // The shared memory tile to swizzle Q. + using Smem_tile_q = fmha::Smem_tile_a; + + // The global memory tile to load K. + using Gmem_tile_k = Gmem_tile_k_; + + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_b; + + // The global memory tile to load V. + using Gmem_tile_v = Gmem_tile_v_; + + // The shared memory tile to swizzle V. + using Smem_tile_v = fmha::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = Gmem_tile_o_; + // The shared memory tile for O. + using Smem_tile_o = fmha::Smem_tile_o; + + // Make sure the number of threads match. + static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + + // The amount of shared memory needed to load/store O. + enum { BYTES_PER_SMEM_O = Smem_tile_o::BYTES_PER_TILE }; + + // The amount of shared memory needed to load Q and store O. + enum { + BYTES_PER_SMEM_QO = + NO_LOOP ? Smem_tile_o::BYTES_PER_TILE : Smem_tile_q::BYTES_PER_TILE + BYTES_PER_SMEM_O + }; + + // The amount of shared memory needed for Q, K, V and O. + enum { BYTES_PER_SMEM = fmha::Max::VALUE }; + + // Make sure we have enough shared memory. + static_assert((NO_LOOP + ? 
Smem_tile_o::BYTES_PER_TILE + : Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE) <= BYTES_PER_SMEM, + ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // Instruction traits. + typename Traits_, + // The global memory tile for Q, K and V. + template class Gmem_tile_q_, + // The global memory tile for the output. + template class Gmem_tile_o_, + // Sequence length for K/V. + int S_KV, + // The hidden dimension. + int D, + // The iteration step of the outer loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD_, + // The flags to control the behaviour of LDGs. + uint32_t FLAGS, + // The version of the kernel. + int VERSION_, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true> +struct Kernel_traits_fmhca_ { + // The instruction traits for the Q*K product. + using Traits_p = typename Traits_reuse::Traits; + // The instruction traits for the P*V product. Hack to change the traits for Volta HMMA. + using Traits_o = typename Traits_o_adapter::Traits; + // The instruction traits for the epilogue of the 2nd GEMM. Always use FP16. + using Traits_e = typename Traits_o_adapter::Traits; + + // The CTA description for the 1st GEMM. + using Cta_tile_p = + typename Traits_p::template Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = + typename Traits_o::template Cta_tile_extd; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = typename Traits_p::template Mma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = typename Traits_o::template Mma_tile; + + // Constraints on the K dimension. + static_assert(Mma_tile_p::K_PER_MMA <= D, ""); + static_assert(Mma_tile_o::K_PER_MMA <= S_KV, ""); + + // The version. + enum { VERSION = VERSION_ }; + + // The mask version + enum { MASK_VERSION = VERSION_ }; + + // Whether use causal mask or not. + enum { CAUSAL_MASK = MASK_VERSION >= 3 }; + + // Whether use the sliding window attention or not. + enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION == 4 }; + + // Do we use LDGSTS for Q, K or V. + enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u }; + + enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u }; + + enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u }; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u }; + + // Do we use the scale max trick. + enum { USE_SCALE_MAX = (FLAGS & 0x10u) != 0u }; + + // Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d. + enum { HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u }; + + // Keep full K matrix in registers. + enum { K_IN_REGS = (FLAGS & 0x40) == 0u }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = 0 }; + + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ }; + + // The global memory tile to load Q. + using Gmem_tile_q = Gmem_tile_q_; + + // The shared memory tile to swizzle Q. + using Smem_tile_q = fmha::Smem_tile_a; + + // The global memory tile to load K. + using Gmem_tile_k = Gmem_tile_q_; + + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_b; + + // The global memory tile to load V. + using Gmem_tile_v = Gmem_tile_q_; + + // The shared memory tile to swizzle V. 
+ using Smem_tile_v = fmha::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = Gmem_tile_o_; + // The shared memory tile for O. + using Smem_tile_o = fmha::Smem_tile_o; + + // Make sure the number of threads match. + static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + + // The amount of shared memory needed to load Q and store O. + enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K, V and O. + enum { BYTES_PER_SMEM = fmha::Max::VALUE }; + + // Make sure we have enough shared memory. + static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits_, + // The sequence length. + int S, + // The hidden size per head. + int VALID_D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD_, + // The flags. + uint32_t FLAGS = 0x8, + // The mask version of the kernel + int MASK_VERSION_ = 2> +struct Kernel_traits_interleaved_v2_ { + // The instruction traits. + using Traits = typename Traits_reuse::Traits; + using Traits_p = Traits; + using Traits_o = Traits; + + // The padded D dimension + enum { D = Next_power_of_two::VALUE }; + + // The CTA description for the 1st GEMM. + using Cta_tile_p = + typename Traits::template Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = + typename Traits::template Cta_tile_extd; + + // The version. + enum { VERSION = 2 }; + + enum { MASK_VERSION = MASK_VERSION_ }; + + // Whether use causal mask or not. + enum { CAUSAL_MASK = MASK_VERSION_ >= 3 }; + + // Whether use the sliding window attention or not. + enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4 }; + + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ }; + + // Do we use LDGSTS for Q, K or V. + enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u }; + + enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u }; + + enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u }; + + // Do we use one buffer for K and V. + enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u }; + + // Do we use the scale max trick. + enum { USE_SCALE_MAX = (FLAGS & 16) != 0u }; + + // The global memory tile to load Q. + using Gmem_tile_q = + fmha::v2::Gmem_tile_qkv_interleaved; + // The shared memory tile to swizzle Q. + using Smem_tile_q = fmha::Smem_tile_qk_interleaved_a; + + // The global memory tile to load K. 
+ using Gmem_tile_k = + fmha::v2::Gmem_tile_qkv_interleaved; + // The shared memory tile to swizzle K. + using Smem_tile_k = fmha::Smem_tile_qk_interleaved_b; + + // The global memory tile to load V. + using Gmem_tile_v = + fmha::v2::Gmem_tile_qkv_interleaved; + + // The shared memory tile to swizzle V. + using Smem_tile_v = fmha::Smem_tile_v_interleaved_b; + + // The global memory tile to store O. + using Gmem_tile_o = fmha::v2::Imma_gmem_tile_o_interleaved; + // The shared memory tile for O. + using Smem_tile_o = fmha::Smem_tile_o_interleaved; + + // Make sure the number of threads match. + static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); + + // The number of threads. + enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; + + // Make sure the number of threads matches both CTAs. + static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); + + // The amount of shared memory needed to load Q and K. + enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; + + // The extra amount of shared memory needed to load V. + enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K and V.. + enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; + + // The amount of shared memory needed to load Q and store O. + enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; + + // The amount of shared memory needed for Q, K, V and O. + enum { BYTES_PER_SMEM = fmha::Max::VALUE }; + + // Make sure we have enough shared memory. + static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits_, + // The sequence length. + int S, + // The hidden size per head. + int VALID_D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD_, + // The flags. + uint32_t FLAGS = 0x8, + // The mask version of the kernel + int MASK_VERSION_ = 2> +using Kernel_traits_interleaved_v2 = + Kernel_traits_interleaved_v2_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8> +using Kernel_traits_v1 = Kernel_traits_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. 
+ uint32_t FLAGS = 0x8> +using Kernel_traits_v1_causal_mask = + Kernel_traits_; // MASK_VERSION_ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gmem_tile_o_dispatcher { + template + using Gmem_tile_o = fmha::v2::Gmem_tile_o; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Gmem_tile_o_dispatcher { + template + using Gmem_tile_o = fmha::v2::Gmem_tile_o_uint16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Gmem_tile_o_dispatcher { + template + using Gmem_tile_o = fmha::v2::Gmem_tile_o_bfloat16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The hidden dimension of V. + int DV, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8, + // The attention mask version (see src/mask.h). + int MASK_VERSION = 2, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // The output type. + typename OutputType = typename Traits::A_type, + // The sage attention block size for Q, K and V + int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0> +using Kernel_traits_v2 = + Kernel_traits_::Gmem_tile_o, + S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, + BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The hidden dimension of V. + int DV, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8, + // The attention mask version (see src/mask.h). + int MASK_VERSION = 2, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // The output type. + typename OutputType = typename Traits::A_type, + // The sage attention block size for Q, K and V + int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0> +using Kernel_traits_v2_q_k_v = + Kernel_traits_::Gmem_tile_o, S, D, DV, STEP, WARPS_M, + WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, + SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The hidden dimension of V. + int DV, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. 
+ int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8, + // The attention mask version (see src/mask.h). + int MASK_VERSION = 2, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // The output type. + typename OutputType = typename Traits::A_type, + // The sage attention block size for Q, K and V + int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0> +using Kernel_traits_v2_paged_kv_cache = + Kernel_traits_::Gmem_tile_o, S, D, DV, STEP, WARPS_M, + WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, + SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length. + int S, + // The hidden size per head. + int D, + // The hidden dimension of V. + int DV, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8, + // The attention mask version (see src/mask.h). + int MASK_VERSION = 2, + // Do we use half epilogue for the 2nd GEMM (hmma_fp32) + bool BMM2_FP16_EPILOGUE = true, + // The output type. + typename OutputType = typename Traits::A_type, + // The sage attention block size for Q, K and V + int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0> +using Kernel_traits_v2_contiguous_kv_cache = + Kernel_traits_::Gmem_tile_o, S, D, 0, STEP, WARPS_M, + WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, + SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The sequence length for K and V. + int S_KV, + // The hidden size per head. + int D, + // The number of timesteps per iteration of the main loop. + int STEP, + // The number of vertical warps. + int WARPS_M, + // The number of horizontal warps. + int WARPS_N, + // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K + int CTAS_PER_HEAD, + // The flags. + uint32_t FLAGS = 0x8> +using Kernel_traits_fmhca = + Kernel_traits_fmhca_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/mask.h b/csrc/fmha_v2/fmha/mask.h new file mode 100644 index 0000000000..3219947ccf --- /dev/null +++ b/csrc/fmha_v2/fmha/mask.h @@ -0,0 +1,785 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include "fmha/traits.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask { + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in each dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) { + // The pointer. + packed_mask_ptr_ = reinterpret_cast(params.packed_mask_ptr); + // Take the head into account. + packed_mask_ptr_ += block_info.bidb * params.packed_mask_stride_in_bytes; + // The thread inside the CTA. + packed_mask_ptr_ += tidx * sizeof(uint32_t); + } + + // Load the mask into registers (and expand). + inline __device__ void load(int it) { + // One 32-bit integer per MMA. + uint32_t packed_mask[MMAS_M]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + int offset = (it * MMAS_M + mi) * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t); + fmha::ldg(packed_mask[mi], packed_mask_ptr_ + offset); + } + +// Expand the mask. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + mask_[2 * mi + 0][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 0)); + mask_[2 * mi + 0][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 1)); + mask_[2 * mi + 1][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 2)); + mask_[2 * mi + 1][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 3)); + mask_[2 * mi + 0][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 4)); + mask_[2 * mi + 0][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 5)); + mask_[2 * mi + 1][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 6)); + mask_[2 * mi + 1][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 7)); + } + } + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + return mask_[mi * 2 + ii][ni * 4 + jj]; + } + + // The pointer to the mask. + char const* packed_mask_ptr_; + // The mask after expansion. + bool mask_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask { + // The instruction traits. + using Traits = Volta_hmma_fp16_traits; + // The shape of the MMA tile. + using Mma_tile = typename Traits::Mma_tile; + + // The number of MMAs in each dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) { + // The pointer. + packed_mask_ptr_ = reinterpret_cast(params.packed_mask_ptr); + // Take the head into account. + packed_mask_ptr_ += block_info.bidb * params.packed_mask_stride_in_bytes; + // The thread inside the CTA. + packed_mask_ptr_ += tidx * sizeof(uint32_t); + } + + // Load the mask into registers (and expand). + inline __device__ void load(int it) { + // One 32-bit integer per MMA. + uint32_t packed_mask[MMAS_M]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + int offset = (it * MMAS_M + mi) * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t); + fmha::ldg(packed_mask[mi], packed_mask_ptr_ + offset); + } + +// Expand the mask. 
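// A host-side sketch of the packed-mask expansion performed above: each MMA column
// group ni owns 8 bits of the 32-bit word, and the bit index for a given (ii, jj)
// position inside the MMA can be folded into one expression. This is a restatement of
// the unrolled assignments for illustration, not code taken from the header.
#include <cstdint>
#include <cstdio>

constexpr bool unpack_mask_bit(uint32_t packed, int ni, int ii, int jj) {
  int bit = 8 * ni + (jj & 1) + 2 * ii + 4 * (jj >> 1);  // same mapping as the unrolled loop
  return (packed >> bit) & 1u;
}

int main() {
  uint32_t packed = 0b10110001u;  // hypothetical mask bits for MMA column group ni = 0
  for (int ii = 0; ii < 2; ++ii)
    for (int jj = 0; jj < 4; ++jj)
      std::printf("ii=%d jj=%d valid=%d\n", ii, jj, unpack_mask_bit(packed, /*ni=*/0, ii, jj));
}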
+#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < MMAS_N * 8; ++ii) { + mask_[mi][ii] = packed_mask[mi] & (1u << ii); + } + } + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int, int jj) const { + return mask_[mi][ni * 8 + jj]; + } + + // The pointer to the mask. + char const* packed_mask_ptr_; + // The mask after expansion. + bool mask_[MMAS_M][MMAS_N * 8]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask { + // That implementation works only when WARPS_K is 1. + static_assert(Cta_tile::WARPS_K == 1, ""); + + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : seqlen_(block_info.actual_seqlen), col_loop_step_(0) { + // The decomposition of the thread index into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_n = warp / Cta_tile::WARPS_M; + // The position of the thread. + col_ = block_info.bidn * Cta_tile::N + warp_n * 16 + lane % 4 * 2; + col_init_ = col_; + } + + // Is a given position valid? + inline __device__ bool is_valid(int, int ni, int, int jj) const { + // The position of the thread in the sequence. + int offset = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + offset += (jj & 0x02) * 4 + (jj & 0x1); + // Is it a valid position in the sequence? + return offset < seqlen_; + } + + // BERT Mask: if upper left is invalid, none are valid + inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); } + + // Move mask to next tile (flash attention) + inline __device__ void move() { this->col_ += Cta_tile::N; } + + // Move mask the col by offset (flash attention) + inline __device__ void move_to_offset(int offset) { this->col_ = col_init_ + offset; } + + // Reset mask to the initial col + inline __device__ void reset() { col_ = col_init_; } + + // Load the mask... Nothing to do for real. + inline __device__ void load(int) {} + + // Load the mask... we use it to keep track of to row, col (flash attention). + inline __device__ void load(int, int col_loop_step) { col_loop_step_ = col_loop_step; } + + // The length of the sequence. + int seqlen_; + // The left-most position of the thread in the sequence. + int col_, col_init_; + // The current col iteration + int col_loop_step_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask : public Mask { + // V3 mask is the causal mask (e.g. for GPT) and extends V2 masks (self-attention). + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), row_loop_step_(0) { + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_m = warp % Cta_tile::WARPS_M; + row_ = warp_m * 16 + lane / 4; + } + + inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const { + // The position of the thread in the sequence. 
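// A host-side sketch of the accumulator-element-to-(row, col) mapping that these masks
// rely on for HMMA fp32 tiles: each thread of a warp owns a 2x4 pattern of elements
// inside a 16x16 MMA tile. The lane value is an example, and the warp/block/loop-step
// offsets added in the real kernel are omitted here.
#include <cstdio>

int main() {
  int lane = 5;                     // example lane id within the warp
  int base_row = lane / 4;          // plus warp_m * 16 in the kernel
  int base_col = (lane % 4) * 2;    // plus warp_n * 16 (and the CTA/loop offsets) in the kernel
  for (int ii = 0; ii < 2; ++ii)
    for (int jj = 0; jj < 4; ++jj)
      std::printf("ii=%d jj=%d -> row %2d, col %2d\n", ii, jj,
                  base_row + ii * 8,
                  base_col + (jj & 0x02) * 4 + (jj & 0x01));
}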
+ row = this->row_ + this->row_loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA; + // The position inside the MMA. + row += ii * 8; + + // The position of the thread in the sequence. + col = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + col += (jj & 0x02) * 4 + (jj & 0x1); + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence, i.e. are we in the lower triangle? + return (row >= col); + } + + // GPT Mask: if lower left is invalid, none are valid + inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 1, 0); } + + // Load the mask... we use it to keep track of to row. + inline __device__ void load(int row_loop_step) { row_loop_step_ = row_loop_step; } + + // Load the mask... we use it to keep track of to row, col (flash attention). + inline __device__ void load(int row_loop_step, int col_loop_step) { + row_loop_step_ = row_loop_step; + this->col_loop_step_ = col_loop_step; + } + + // The upper-most position of the thread in the sequence. + int row_; + // Current row step offset. + int row_loop_step_; +}; + +// Specialized mask for MTP (multi-token prediction used in MLA). +template +struct MtpMask : public Mask { + // MTP mask (causal mask) extends from V2 (dense) masks (self-attention). + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // Ctor. + template + inline __device__ MtpMask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), + num_grouped_heads_(params.num_grouped_heads), + row_loop_step_(0) { + // Update the seqlen (excluding all MTP draft tokens). + this->seqlen_ = this->seqlen_ - (block_info.actual_q_seqlen / params.num_grouped_heads) + 1; + + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_m = warp % Cta_tile::WARPS_M; + row_ = warp_m * 16 + lane / 4; + } + + inline __device__ int get_row(int mi, int ii) const { + // The position of the thread in the sequence. + int row = this->row_ + this->row_loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA; + // The position inside the MMA. + row += ii * 8; + return row; + } + + inline __device__ int get_col(int ni, int jj) const { + // The position of the thread in the sequence. + int col = this->col_ + this->col_loop_step_ * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + col += (jj & 0x02) * 4 + (jj & 0x1); + return col; + } + + inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const { + row = get_row(mi, ii); + col = get_col(ni, jj); + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int col = get_col(ni, jj); + + // Is it a valid position in the sequence? + return col < (this->seqlen_ + mtp_token_idx_[mi][ii]); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence, i.e. are we in the lower triangle? + return (row >= col); + } + + // Load the mask... we use it to keep track of to row. 
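// A reference sketch of the MTP (multi-token prediction) mask predicate implemented
// above. The plain causal rule is simply row >= col; for MTP, every group of
// num_grouped_heads rows belongs to one draft token, and each later draft token may
// attend to one more key. All sizes below are hypothetical.
#include <cstdio>

// kv_seqlen_base mirrors the adjusted seqlen_: total KV length minus the number of
// draft tokens (actual_q_seqlen / num_grouped_heads) plus one.
constexpr bool mtp_valid(int row, int col, int kv_seqlen_base, int num_grouped_heads) {
  int mtp_token_idx = row / num_grouped_heads;  // which draft token this row belongs to
  return col < kv_seqlen_base + mtp_token_idx;  // later draft tokens see one more key each
}

int main() {
  // Example: 10 KV tokens total, 2 draft tokens, 4 grouped heads each -> base = 10 - 2 + 1 = 9.
  for (int row = 0; row < 8; ++row)
    std::printf("row %d sees cols [0, %d): first blocked col valid=%d\n",
                row, 9 + row / 4, mtp_valid(row, 9 + row / 4, 9, 4));
}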
+ inline __device__ void load(int row_loop_step) { + row_loop_step_ = row_loop_step; +// Update the MTP token index. +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + mtp_token_idx_[mi][ii] = get_row(mi, ii) / num_grouped_heads_; + } + } + } + + // The number of grouped heads in the row dimension. + int num_grouped_heads_; + // The corresponding MTP token index for each row. + // FIXME: currently we assume 2 rows per thread (volta/hopper-gmma traits are not supported yet). + int mtp_token_idx_[Mma_tile::MMAS_M][2]; + // The upper-most position of the thread in the sequence. + int row_; + // The current row step offset. + int row_loop_step_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The lower triangle attention matrix. +// Assume we only pay attention to past sliding-window-size long sequence. +// v x x x x x x x x +// v v x x x x x x x +// v v v x x x x x x +// v v v v x x x x x +// v v v v v x x x x +// x v v v v v x x x +// x x v v v v v x x +// x x x v v v v v x +// x x x x v v v v v + +template +struct Mask : public Mask { + // V4 mask is the causal mask (e.g. for GPT) plus the sliding-window feature. + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), sliding_window_size_(params.sliding_window_size) {} + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + this->get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence, i.e. are we in the lower triangle? + return (row >= col) && (col >= max(0, row + 1 - sliding_window_size_)); + } + + // The sliding window size. + int sliding_window_size_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The custom mask (from global memory). +template +struct Mask : public Mask { + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in each dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // One 32-bit packed mask holds 4 MMAS_N as one group. + enum { MMA_GROUPS_N = fmha::Div_up::VALUE }; + + // The MMAS_N in the group. + enum { MMAS_N_IN_GROUP = fmha::Min::VALUE }; + + // MMAS_N uses full 32-bit integer packed masks. + enum { FULL_PACKED_MASK = (MMAS_N % 4 == 0) }; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), + packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), + params_packed_mask_stride_in_bytes_(params.packed_mask_stride_in_bytes), + row_offset_(0) { + // Add the thread offset in bytes. + packed_mask_ptr_ += + (block_info.sum_mask_row * params_packed_mask_stride_in_bytes_ + tidx * sizeof(uint32_t)); + } + + // Load the mask... we use it to keep track of row offset. + inline __device__ void load(int row_offset) { row_offset_ = row_offset; } + + // Load the mask into registers (and expand). 
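// A standalone sketch of the sliding-window (v4) causal predicate defined above;
// printing it for a 9x9 tile with a window of 5 reproduces the 'v'/'x' diagram in the
// comment. The tile and window sizes are hypothetical.
#include <algorithm>
#include <cstdio>

constexpr bool sliding_window_valid(int row, int col, int window) {
  return row >= col && col >= std::max(0, row + 1 - window);
}

int main() {
  constexpr int window = 5;
  for (int row = 0; row < 9; ++row) {
    for (int col = 0; col < 9; ++col)
      std::printf("%c ", sliding_window_valid(row, col, window) ? 'v' : 'x');
    std::printf("\n");
  }
}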
+ inline __device__ void load_mask(int col_offset) { + // The packed_mask_offset in the col(N) dimension. + int mask_col_offset = int(col_offset / (Mma_tile::N_PER_MMA_PER_CTA * 4)) * + Cta_tile::THREADS_PER_CTA * sizeof(uint32_t); + // When MMAS_N < 4, one loaded packed_mask can be expanded to boolean masks + // of multiple iterations. + int local_col = FULL_PACKED_MASK ? 0 : (col_offset % (Mma_tile::N_PER_MMA_PER_CTA * 4)); + // The local mma ni if MMAS_N < 4. + int local_ni = local_col / 16; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The M dimension offset. + int offset = + (row_offset_ + mi * Mma_tile::M_PER_MMA_PER_CTA) * params_packed_mask_stride_in_bytes_; + // The N dimension offset. + offset += mask_col_offset; + // Set predicate to true only when next 32-bit packed mask is needed. + bool pred = local_col == 0; +#pragma unroll + for (int ni = 0; ni < MMA_GROUPS_N; ++ni) { + // The MMAS_N group offset. + if (pred) { + fmha::ldg(packed_mask_[mi][ni], + packed_mask_ptr_ + offset + ni * Cta_tile::THREADS_PER_CTA * sizeof(uint32_t)); + } + } + } + +// Expand the mask. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMA_GROUPS_N; ++ni) { +#pragma unroll + for (int nni = 0; nni < MMAS_N_IN_GROUP; ++nni) { + mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 0] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 0)); + mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 1] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 1)); + mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 0] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 2)); + mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 1] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 3)); + mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 2] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 4)); + mask_[2 * mi + 0][(ni * 4 + nni) * 4 + 3] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 5)); + mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 2] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 6)); + mask_[2 * mi + 1][(ni * 4 + nni) * 4 + 3] = + packed_mask_[mi][ni] & (1u << (8 * (nni + local_ni) + 7)); + } + } + } + } + + // Move mask the col by offset (flash attention) + inline __device__ void move_to_offset(int col_offset) { load_mask(col_offset); } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + return mask_[mi * 2 + ii][ni * 4 + jj]; + } + + // Current row step offset. + int row_offset_; + + // The pointer to the mask. + char const* packed_mask_ptr_; + // The stride in the n dimension. + int64_t const params_packed_mask_stride_in_bytes_; + // The packed mask (one 32-bit integer per MMA GROUP, MMAS_M * 2 rows, MMA_GROUPS_N * 16 cols). + uint32_t packed_mask_[MMAS_M][MMA_GROUPS_N]; + // The mask after expansion. + bool mask_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask { + // The instruction traits. + using Traits = Volta_hmma_fp16_traits; + // The shape of the MMA tile. + using Mma_tile = typename Traits::Mma_tile; + + // That implementation works only when WARPS_K is 1. + static_assert(Cta_tile::WARPS_K == 1, ""); + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : seqlen_(block_info.actual_seqlen) { + // The decomposition of the thread index into warp/lane. 
+ int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_n = warp / Cta_tile::WARPS_M; + // The position of the thread. + col_ = block_info.bidn * Cta_tile::N + warp_n * 16 + (lane & 0x08) / 2; + col_init_ = col_; + } + + // Is a given position valid? + inline __device__ bool is_valid(int, int ni, int, int jj) const { + // The position of the thread in the sequence. + int offset = this->col_ + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + offset += (jj & 0x04) * 2 + (jj & 0x03); + // Is it a valid position in the sequence? + return offset < seqlen_; + } + + // Load the mask... Nothing to do for real. + inline __device__ void load(int) {} + + // Reset mask to the initial col + inline __device__ void reset() { col_ = col_init_; } + + // Move mask to next tile (flash attention) + inline __device__ void move() { this->col_ += Cta_tile::N; } + + // Move mask the col by offset (flash attention) + inline __device__ void move_to_offset(int offset) { this->col_ = col_init_ + offset; } + + // The length of the sequence. + int const seqlen_; + // The left-most position of the thread in the sequence. + int col_, col_init_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask + : public Mask { + // V3 mask is the causal mask (e.g. for GPT) and extends V2 masks (self-attention). + using Base = Mask; + + // The shape of the MMA tile. + using Mma_tile = typename Base::Mma_tile; + + // Ctor. + template + inline __device__ Mask(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), loop_step_(0) { + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the warp. + int warp_m = warp % Cta_tile::WARPS_M; + row_ = warp_m * 16 + (lane & 0x07) + (lane & 0x10) / 2; + } + + inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const { + // The position of the thread in the sequence. + row = this->row_ + this->loop_step_ + mi * Mma_tile::M_PER_MMA_PER_CTA; + + // The position of the thread in the sequence. + col = this->col_ + ni * Mma_tile::N_PER_MMA_PER_CTA; + // The position inside the MMA. + col += (jj & 0x04) * 2 + (jj & 0x03); + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence, i.e. are we in the lower triangle? + return (row >= col) && (col < this->seqlen_); + } + + // GPT Mask: if lower left is invalid, none are valid + inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); } + + // Load the mask... we use it to keep track of to row. + inline __device__ void load(int loop_step) { loop_step_ = loop_step; } + + // The upper-most position of the thread in the sequence. + int row_; + // Current iteration. 
+ int loop_step_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask_dispatcher {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask_dispatcher + : public Mask { + using Base = Mask; + + template + inline __device__ Mask_dispatcher(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask_dispatcher : public MtpMask { + using Base = MtpMask; + + template + inline __device__ Mask_dispatcher(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Mask_hopper { + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // Ctor. + template + inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx) + : seqlen_(block_info.actual_seqlen) { + // For Hopper the warp distribution is always 4x1 within a warpgroup. + // So maybe there is some assumptions/optimizations to be made here. + + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int warp_n = warp / 4; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + col_ = warp_n * Mma_tile::N_PER_WARP_GROUP + (lane % 4) * 2; + } + + // Is a given position valid? + inline __device__ bool is_valid(int, int ni, int, int jj) const { + // The position of the thread in the sequence. + int offset = this->col_ + ni * Mma_tile::N_PER_MMA; + // The position inside the MMA. + offset += (jj / 2) * 8 + (jj % 2); + // Is it a valid position in the sequence? + return offset < seqlen_; + } + + // Load the mask... Nothing to do for real. + inline __device__ void load(int) {} + + // The length of the sequence. + int const seqlen_; + // The left-most position of the thread in the sequence. + int col_; +}; + +template +struct Mask_hopper { + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // Ctor. + template + inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx) { + // For Hopper the warp distribution is always 4x1 within a warpgroup. + // So maybe there is some assumptions/optimizations to be made here. + + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int warp_n = warp / 4; + int warp_m = warp % 4; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + col_ = warp_n * Mma_tile::N_PER_WARP_GROUP + (lane % 4) * 2; + row_base_ = warp_m * 16 + lane / 4; + row_ = row_base_; + } + + inline __device__ void get_row_col(int& row, int& col, int mi, int ni, int ii, int jj) const { + // The row position of the thread in the sequence. + row = row_ + mi * Mma_tile::M_PER_MMA + ii * 8; + + // The position of the thread in the sequence. + col = this->col_ + ni * Mma_tile::N_PER_MMA; + // The position inside the MMA. + col += (jj / 2) * 8 + (jj % 2); + } + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence? 
+ return col <= row; + } + + // Load the mask... Nothing to do for real. + inline __device__ void load(int loop_step) { row_ = row_base_ + loop_step * Cta_tile::M; } + + // The left-most position of the thread in the sequence. + int row_, row_base_, col_; +}; + +template +struct Mask_hopper : public Mask_hopper { + // V4 mask is the causal mask (e.g. for GPT) plus the sliding-window feature. + using Base = Mask_hopper; + + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // Ctor. + template + inline __device__ Mask_hopper(Params const& params, Block_info const& block_info, int tidx) + : Base(params, block_info, tidx), sliding_window_size_(params.sliding_window_size) {} + + // Is a given position valid? + inline __device__ bool is_valid(int mi, int ni, int ii, int jj) const { + int row, col; + this->get_row_col(row, col, mi, ni, ii, jj); + + // Is it a valid position in the sequence? + return is_valid(row, col); + } + + // Is a given position valid? + inline __device__ bool is_valid(int row, int col) const { + // Is it a valid position in the sequence? + return col <= row && col >= max(0, row + 1 - sliding_window_size_); + } + + // The sliding window size for attention. + int sliding_window_size_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/numeric_types.h b/csrc/fmha_v2/fmha/numeric_types.h new file mode 100644 index 0000000000..1c3ec1a615 --- /dev/null +++ b/csrc/fmha_v2/fmha/numeric_types.h @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include + +#include + +#pragma once + +#if CUDART_VERSION >= 11080 +// TODO Better way? 
+#define FMHA_CUDA_SUPPORTS_FP8 true +#endif +#include +#if FMHA_CUDA_SUPPORTS_FP8 +#include +#endif +namespace fmha { + +using fp16_t = uint16_t; +using fp32_t = float; +using tf32_t = uint32_t; +using bf16_t = nv_bfloat16; +#if FMHA_CUDA_SUPPORTS_FP8 +using e4m3_t = __nv_fp8_e4m3; +using e5m2_t = __nv_fp8_e5m2; +#else +using e4m3_t = char; +using e5m2_t = char; +#endif + +static constexpr float MAX_E4M3 = 448.f; // 0x7E 2^8 * 1.75 +static constexpr float MAX_E5M2 = 57344.f; // 0x7B 2^15 * 1.75 + +template +__host__ __device__ constexpr inline float Softmax_fp_quant_scale(); + +template <> +__host__ __device__ constexpr inline float Softmax_fp_quant_scale() { + // Softmax has max output of 1.0, therefore we choose fp32-to-fp8 quantization scale as the + // largest power-of-2 below the e4m3 limit: + // 2^(floor(log2(E4M3_MAX / amax_exp_p))) = 2^(floor(log2(448 / 1))) = 2 ^ 8 + return 256.f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/paged_kv_cache.h b/csrc/fmha_v2/fmha/paged_kv_cache.h new file mode 100644 index 0000000000..a8e13a61d0 --- /dev/null +++ b/csrc/fmha_v2/fmha/paged_kv_cache.h @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include + +namespace fmha { + +// This needs to be aligned with the definition in TRT-LLM +struct Kv_block_array { + using PtrType = int32_t; + + // Maximum number of sequences supported by the kv-cache. + int32_t mMaxSeqs; + // Max number of blocks per sequence + int32_t mMaxBlocksPerSeq; + // Number of tokens. It must be power of 2. + int32_t mTokensPerBlock; + // Exponent of number of tokens with base 2. + // E.g. for mTokensPerBlock 64, mTokensPerBlockLog2 equals to 6 + int32_t mTokensPerBlockLog2; + // Table maps logical block idx to the data pointer of k/v cache block pool + // Shape [B, W, 2, M], where 2 is table for K and V, + // B is current number of sequences + // W is beam width + // M is Max number of blocks per sequence + + // Size of KV cache blocks in bytes (H*D*T*sizeof(DataType)) + int32_t mBytesPerBlock; + // Pointer to beginning of pool. + void* mPoolPtr; + // Pointer to block offsets. 
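// A minimal sketch of how a paged KV-cache position would be split using the fields
// above: the block index comes from shifting by mTokensPerBlockLog2 and the in-block
// token from masking, which is why the exponent is precomputed. This is illustrative
// only; the real lookup goes through the [B, W, 2, M] block-offset table described above.
#include <cstdint>
#include <cstdio>

int main() {
  int32_t tokens_per_block = 64;                        // must be a power of two
  int32_t tokens_per_block_log2 = 6;                    // mTokensPerBlockLog2
  int32_t token = 150;                                  // hypothetical token position
  int32_t block_idx = token >> tokens_per_block_log2;   // which cache block
  int32_t in_block = token & (tokens_per_block - 1);    // token offset inside that block
  std::printf("token %d -> block %d, offset %d\n", token, block_idx, in_block);
}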
+ PtrType* mBlockOffsets; + + Kv_block_array() = default; + + Kv_block_array(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, + int32_t bytesPerBlock, void* poolPtr) + : mMaxSeqs(batchSize), + mMaxBlocksPerSeq(maxBlocksPerSeq), + mTokensPerBlock(tokensPerBlock), + mBytesPerBlock{bytesPerBlock}, + mPoolPtr{poolPtr}, + mBlockOffsets{nullptr} { + float const tokensPerBlockSeqLog2 = log2(mTokensPerBlock); + mTokensPerBlockLog2 = static_cast(tokensPerBlockSeqLog2); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/smem_tile.h b/csrc/fmha_v2/fmha/smem_tile.h new file mode 100644 index 0000000000..dd75cf7bdb --- /dev/null +++ b/csrc/fmha_v2/fmha/smem_tile.h @@ -0,0 +1,2071 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The description of the tile computed by this CTA. + typename Cta_tile, + // The number of rows in the 2D shared memory buffer. + int M_, + // The number of cols. + int N_, + // The size in bits of each element. + int BITS_PER_ELEMENT_, + // The number of bytes per STS. + int BYTES_PER_STS_ = 16, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_ = 1, + // Do we enable the fast path for LDS.128 and friends. + int ENABLE_LDS_FAST_PATH_ = 0, + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + int ROWS_PER_XOR_PATTERN_ = 8, + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + int COLS_PER_XOR_PATTERN_ = 1, + // Use or not predicates + bool USE_PREDICATES_ = true, + // Use TMA or not, + bool USE_TMA_ = false, + // The leading dim elements in shared memory + int LEAD_DIM_ELEMENTS_ = N_> +struct Smem_tile_without_skews { + // The type of this tile + using Smem_tile_ = + Smem_tile_without_skews; + + static constexpr bool USE_TMA = USE_TMA_; + + // The size in bits of each element. + enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; + + // The size in bytes of a single STS. + enum { BYTES_PER_STS = BYTES_PER_STS_ }; + + // The number of elements per STS. + enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; + + // To support arbitrary N, we pad some values to a power-of-2. + enum { N_WITH_PADDING = Next_power_of_two::VALUE }; + + // The number of bytes per row without packing of rows. + enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; + + // The number of bytes per row -- we want at least 128B per row. + enum { BYTES_PER_ROW = Max::VALUE }; + + // The number of rows in shared memory (two rows may be packed into a single one). + enum { ROWS = M_ * N_ / LEAD_DIM_ELEMENTS_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW }; + + // The number of threads per row. 
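// A quick sketch of the row/packing arithmetic above for one hypothetical tile:
// 128 x 64 elements of 16 bits, 16-byte stores, 128 threads per CTA. The names mirror
// the enums but the numbers are only an example.
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kM = 128, kN = 64, kBitsPerElement = 16;
  constexpr int kBytesPerSts = 16, kThreadsPerCta = 128;
  constexpr int kBytesPerRowBeforePacking = kN * kBitsPerElement / 8;      // 128 bytes
  constexpr int kBytesPerRow = std::max(kBytesPerRowBeforePacking, 128);   // at least 128B per row
  constexpr int kRows = kM * kBytesPerRowBeforePacking / kBytesPerRow;     // rows after packing
  constexpr int kThreadsPerRow = std::min(kBytesPerRow / kBytesPerSts, kThreadsPerCta);
  std::printf("rows=%d bytes/row=%d threads/row=%d\n", kRows, kBytesPerRow, kThreadsPerRow);
}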
+ enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; + + // The number of threads per row. + enum { THREADS_PER_ROW = Min::VALUE }; + + // The number of STS per row. + enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; + + // It must be at least one. + static_assert(STS_PER_ROW >= 1, ""); + + // The number of rows written with a single STS. + enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;) + static_assert(ROWS_PER_STS >= 1, ""); + + // The number of STS needed to store all rows. + enum { STS_PER_COL = Div_up::VALUE }; + + // The number of STS in total. + enum { STS = STS_PER_COL * STS_PER_ROW }; + + // The size of one buffer in bytes in shared memory. + enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; + + // The number of buffers. + enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; + + // The size in bytes of total buffers. + enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; + + // The boundary for smem_read_offset and smem_write_offset increment. + enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; + + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; + + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; + + // Use or not predicates + enum { USE_PREDICATES = USE_PREDICATES_ }; + + // The bytes of one shmem row + enum { BYTES_PER_SHMEM_ROW = 128 }; + + // The type of elements that are stored in shared memory by each thread. + using Store_type = typename Uint_from_size_in_bytes::Type; + + // Ctor. + inline __device__ Smem_tile_without_skews(void* smem, int tidx) + : smem_(__nvvm_get_smem_pointer(smem)) { + // The row written by a thread. See doc/mma_smem_layout.xlsx. + int smem_write_row = tidx / THREADS_PER_ROW; + + // The XOR pattern. + int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; + // Compute the column and apply the XOR pattern. + int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; + + // The offset. + this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS; + + // That code is expected to trigger the utilization of the URF by the compiler. + this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); + this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); + } + + // Compute the store pointers. + template + inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + // Decompose the STS into row/col. + int row = ii % STS_PER_COL; + int col = ii / STS_PER_COL; + + // Compute the immediate. + int imm = row; + + // Assemble the offset. + int offset = smem_write_offset_ + imm * ROWS_PER_STS * BYTES_PER_ROW; + + // Take the column into account. + if (STS_PER_ROW > 1) { + offset += col * THREADS_PER_ROW * BYTES_PER_STS; + } + + // Apply the XOR pattern if needed. 
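// A host-side sketch of the XOR swizzle applied above when computing STS offsets: the
// 16-byte column written by a thread is XOR-ed with a per-row pattern, so consecutive
// rows land on permuted columns. The configuration below (8 threads per row, 8-row XOR
// pattern, 128-byte rows) is hypothetical.
#include <cstdio>

int main() {
  constexpr int kThreadsPerRow = 8, kRowsPerXorPattern = 8, kColsPerXorPattern = 1;
  for (int row = 0; row < kRowsPerXorPattern; ++row) {
    int xor_pattern = row % kRowsPerXorPattern * kColsPerXorPattern;
    std::printf("row %d 16B-column order:", row);
    for (int t = 0; t < kThreadsPerRow; ++t)
      std::printf(" %d", t ^ xor_pattern);  // byte offset would be row * 128 + column * 16
    std::printf("\n");
  }
}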
+ if (ROWS_PER_STS < ROWS_PER_XOR_PATTERN) { + int const m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; + offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; + } + +// Assemble the final pointer :) +#pragma unroll + for (int k = 0; k < K; k++) { + ptrs[ii * K + k] = smem_ + offset + k * (BYTES_PER_STS / K) + smem_write_buffer_; + } + } + } + + inline __device__ void debug_reset() { + for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for (int row = 0; row < ROWS; ++row) { + for (int col = 0; col < BYTES_PER_ROW; col += 4) { + if (threadIdx.x == 0) { + uint32_t val = 0x0; + sts(val, smem_ + row * BYTES_PER_ROW + col + buffer); + } + } + } + } + } + + // Print the content of the tile (only for debug ;)). + inline __device__ void debug_print() const { + for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { + for (int row = 0; row < ROWS; ++row) { + for (int col = 0; col < BYTES_PER_ROW; col += 4) { + if (threadIdx.x == 0) { + uint32_t val; + lds(val, smem_ + row * BYTES_PER_ROW + col + buffer); + printf( + "block=(x=%2d, y=%2d, z=%2d) (smem_=0x%08x, buffer=%2d, row=%2d, " + "byte=%4d)=0x%08x\n", + blockIdx.x, blockIdx.y, blockIdx.z, smem_, buffer, row, col, val); + } + } + } + } + } + + // Move the read offset to next buffer. + inline __device__ void move_to_next_read_buffer() { + if (BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) { + this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_read_buffer_ += BYTES_PER_BUFFER; + } + } + + // Move the read offset to next buffer. TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer() { this->move_to_next_read_buffer(); } + + // Move the read offset to next N buffer (circular-buffer). + inline __device__ void move_to_next_read_buffer(int N) { + if (BUFFERS_PER_TILE > 1) { + this->smem_read_buffer_ += N * BYTES_PER_BUFFER; + this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; + } + } + + // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!! + inline __device__ void move_next_read_buffer(int N) { this->move_to_next_read_buffer(N); } + + // Move the write offset to next buffer. + inline __device__ void move_to_next_write_buffer() { + if (BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) { + this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; + } else if (BUFFERS_PER_TILE > 1) { + this->smem_write_buffer_ += BYTES_PER_BUFFER; + } + } + + // Move the write offset to next buffer. TODO: Remove that member function! + inline __device__ void move_next_write_buffer() { this->move_to_next_write_buffer(); } + + // Move the read offset. + inline __device__ void move_read_offset(int delta) { this->smem_read_offset_ += delta; } + + // Move the write offset. + inline __device__ void move_write_offset(int delta) { this->smem_write_offset_ += delta; } + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N]) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(Store_type const (&data)[N], uint32_t (&preds)[M]) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + sts(smem_ptrs, data, preds); + } + + // Store to the tile in shared memory. 
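// A small sketch of the multi-buffer rotation performed by move_to_next_read_buffer()
// above: the byte offset steps through BUFFERS_PER_TILE buffers and wraps once the last
// one is reached. Buffer sizes here are hypothetical.
#include <cstdio>

int main() {
  constexpr int kBytesPerBuffer = 4096, kBuffers = 3;
  constexpr int kBytesPerTile = kBytesPerBuffer * kBuffers;
  constexpr int kIncBoundary = kBytesPerTile - kBytesPerBuffer;
  int offset = 0;
  for (int step = 0; step < 7; ++step) {
    std::printf("step %d -> buffer offset %d\n", step, offset);
    // Same update rule: wrap back to the first buffer after the last one.
    if (offset >= kIncBoundary) offset -= kIncBoundary;
    else offset += kBytesPerBuffer;
  }
}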
+ template + inline __device__ void store(Store_type const (&data)[N], uint32_t preds) { + this->store(data, preds); + } + + // Store to the tile in shared memory. TODO: Remove last template arguments. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t (&preds)[M]) { + uint32_t smem_ptrs[N]; + this->compute_store_pointers(smem_ptrs); + ldgsts(smem_ptrs, gmem_ptrs, preds); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { + uint32_t tmp[1] = {preds}; + this->store(gmem_ptrs, tmp); + } + + // Store to the tile in shared memory. + template + inline __device__ void store(void const* (&gmem_ptrs)[N], uint32_t preds) { + uint32_t tmp[1] = {preds}; + this->store(gmem_ptrs, tmp); + } + + inline __device__ void add_smem_barrier_base(uint64_t*) {} + + // The shared memory pointer. + uint32_t smem_; + // The read offset. Reserve 4 offsets if needed. + int smem_read_offset_; + // The write offset. + int smem_write_offset_; + // The buffer base offset for read. + int smem_read_buffer_; + // The buffer base offset for write. + int smem_write_buffer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Use TMA +template < + // The description of the tile computed by this CTA. + typename Cta_tile, + // The number of rows in the 2D shared memory buffer. + int M_, + // The number of cols. + int N_, + // The size in bits of each element. + int BITS_PER_ELEMENT_, + // The number of bytes per STS. Not relevant for TMA + int BYTES_PER_STS_, + // The number of buffers. (Used in multistage and double buffer cases.) + int BUFFERS_PER_TILE_, + // Do we enable the fast path for LDS.128 and friends. + int ENABLE_LDS_FAST_PATH_, + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + int ROWS_PER_XOR_PATTERN_, + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + int COLS_PER_XOR_PATTERN_, + // Use or not predicates + bool USE_PREDICATES_, + // The leading dim elements in shared memory + int LEAD_DIM_ELEMENTS_> +struct Smem_tile_without_skews + : public Smem_tile_without_skews { + // Base struct + using Base = + Smem_tile_without_skews; + static constexpr bool USE_TMA = true; + + // Tile size overrides. 
STS per thread not relevant for TMA + static constexpr int BYTES_PER_BUFFER = M_ * N_ * Base::BITS_PER_ELEMENT / 8; + static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * Base::BUFFERS_PER_TILE; + static constexpr int BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER; + // The number of bytes per barrier + static constexpr int BYTES_PER_BARRIER = 8; + + // Ctor + inline __device__ Smem_tile_without_skews(void* smem, int tidx) : Base(smem, tidx) { + this->smem_write_offset_ = __nvvm_get_smem_pointer(smem); + this->smem_barrier_offset_ = 0; + this->elect_one_ = elect_one_sync(); + } + + inline __device__ void add_smem_barrier_base(uint64_t* smem_barrier) { + this->smem_barrier_ = smem_barrier; + this->smem_barrier_offset_ = __nvvm_get_smem_pointer(this->smem_barrier_); + } + + /** + * \brief load tensor blocks from global memory and stores to shared memory using tma instructions + * + * \param p_desc pointer to tma descriptor masked as const void* pointer + * \param smem_offset shared memory offset in bytes relative to smem_write_buffer_ + * \param coord0 tensor access coordinate in dimension 1, used by tma load + * \param coord1 tensor access coordinate in dimension 2, used by tma load + * \param coord2 tensor access coordinate in dimension 3, used by tma load + * \param coord3 tensor access coordinate in dimension 4, used by tma load + * \param coord4 tensor access coordinate in dimension 5, used by tma load + * \param filter_offsets encodes multicast cta id and filter offsets + */ + template + inline __device__ void store(void const* p_desc, unsigned const& smem_offset, int32_t coord0, + int32_t coord1, int32_t coord2, int32_t coord3, int32_t coord4, + uint16_t filter_offsets, uint16_t mcast_cta_mask, + uint64_t mem_desc) { + uint32_t smem = this->smem_write_offset_ + smem_offset; + fmha::utmaldg( + reinterpret_cast(p_desc), smem, unsigned(this->smem_barrier_offset_), + coord0, coord1, coord2, coord3, coord4, filter_offsets, mcast_cta_mask, mem_desc, + this->elect_one_); + } + + // Same function as above but for runtime cga dimension + template + inline __device__ void store(void const* p_desc, unsigned const& smem_offset, int32_t coord0, + int32_t coord1, int32_t coord2, int32_t coord3, int32_t coord4, + uint16_t filter_offsets, uint16_t mcast_cta_mask, uint64_t mem_desc, + bool mcast_enabled) { + uint32_t smem = this->smem_write_offset_ + smem_offset; + fmha::utmaldg(reinterpret_cast(p_desc), smem, + unsigned(this->smem_barrier_offset_), coord0, coord1, coord2, + coord3, coord4, filter_offsets, mcast_cta_mask, mcast_enabled, + mem_desc, this->elect_one_); + } + + // Move the write offset to next buffer. + inline __device__ void move_next_write_buffer() { + if (Base::BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += (this->smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) + ? -BYTES_PER_TILE_INC_BOUNDARY + : BYTES_PER_BUFFER; + this->smem_barrier_offset_ += + (this->smem_barrier_offset_ >= Base::BUFFERS_PER_TILE * BYTES_PER_BARRIER) + ? -Base::BUFFERS_PER_TILE * BYTES_PER_BARRIER + : BYTES_PER_BARRIER; + } + } + + inline __device__ void move_next_write_buffer(int buffer_id) { + if (Base::BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ = this->smem_ + buffer_id * BYTES_PER_BUFFER; + } + this->smem_barrier_offset_ = __nvvm_get_smem_pointer(this->smem_barrier_ + buffer_id); + } + + // Move the read offset to next buffer. 
+ // do nothing, as it is controlled by gmma desc + inline __device__ void move_next_read_buffer() {} + + uint64_t* smem_barrier_; + uint32_t smem_barrier_offset_; + // elect one thread to issue utmaldg + uint32_t elect_one_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The layout of the tile. + typename Layout, + // The size of the STS. + int BYTES_PER_STS = 16, + // The number of buffers per tile. + int BUFFERS_PER_TILE = 1, + // Use or not predicates + bool USE_PREDICATES = true> +struct Smem_tile_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_volta_a { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 1 : (N_IN_BITS <= 512 ? 2 : 4) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask { + // The potential mask. + enum { HALF = MMAS_K_WITH_PADDING / 2 }; + + // The remainder. + enum { MOD = MMAS_K % HALF }; + + // The final value. + enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> { + enum { VALUE = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_reset_mask { + enum { VALUE = MMAS_K - 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_volta_a::VALUE> +struct Smem_tile_volta_row_a + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_volta_row_a(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/xmma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_M = Warp_masks::M; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. 
+ int smem_read_row, smem_read_col; + if (Base::N_WITH_PADDING >= 64) { + smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + + (tidx & 0x10) / 2 + (tidx & 0x07); + smem_read_col = (tidx & 0x03); + } else if (Base::N_WITH_PADDING == 32) { + smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + + (tidx & 0x10) / 4 + (tidx & 0x06) / 2; + smem_read_col = (tidx & 0x02) / 2 + (tidx & 0x01) * 4; + } else { + assert(false); + } + + // For WARPS_K > 1, we do not support Base::N_WITH_PADDING < 64 for the moment. + static_assert(WARPS_K <= 2 && (WARPS_K == 1 || Base::N_WITH_PADDING >= 64), ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop.- + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Move the offset to the next position. See doc/xmma_smem_layout.xlsx. + this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { + // Jump over as many rows as needed. + int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // TODO: Could we fuse smem_read_buffer and smem_read_offset? + uint4 tmp; + lds(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + a[mi].reg(2) = tmp.z; + a[mi].reg(3) = tmp.w; + } + + // Move the offset to the next position. See doc/xmma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_volta_row_a { + // The traits class. + using Traits = fmha::Volta_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_volta_row_a; + + // Ctor. 
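The ladder of XOR updates at the end of load() above walks smem_read_offset_ through every padded K slot exactly once before returning to its starting column. A small host-side walk of that schedule, assuming MMAS_K_WITH_PADDING = 8 and 16B loads (only the branches active for that size are kept), illustrates this:

// Walks the read-offset XOR schedule and confirms it visits 8 distinct 16B
// columns and ends back at offset 0. Sizes are assumed, not taken from a tile.
#include <cstdio>
#include <set>

int main() {
  constexpr int BYTES_PER_LDS = 16;
  constexpr int MMAS_K_WITH_PADDING = 8;
  int offset = 0;
  std::set<int> visited;
  for (int ki = 0; ki < MMAS_K_WITH_PADDING; ++ki) {
    visited.insert(offset);
    if (MMAS_K_WITH_PADDING >= 8 && ki % 4 == 3) {
      offset ^= 7 * BYTES_PER_LDS;
    } else if (MMAS_K_WITH_PADDING >= 4 && ki % 2 == 1) {
      offset ^= 3 * BYTES_PER_LDS;
    } else if (MMAS_K_WITH_PADDING >= 2) {
      offset ^= 1 * BYTES_PER_LDS;
    }
  }
  std::printf("distinct columns: %zu, final offset: %d\n", visited.size(), offset);
  return 0;
}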
+ inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_turing_a { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 128 ? 1 : (N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8)) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_turing_a::VALUE> +struct Smem_tile_turing_row_a + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = + Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_turing_row_a(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_M = Warp_masks::M; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 1, + ""); + + if (Base::ROWS_PER_XOR_PATTERN == 8) { + smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x0f); + smem_read_col = (tidx & 0x07); + } else if (Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x0e) / 2; + smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4; + } else if (Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 4 + (tidx & 0x0c) / 4; + smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2; + } else if (Base::ROWS_PER_XOR_PATTERN == 1) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 8 + (tidx & 0x1f) / 8; + smem_read_col = (tidx & 0x07); + } + + static_assert(WARPS_K <= 2, ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K; + } + + // The shared memory offset. 
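Rows_per_xor_pattern_turing_a above picks how many shared-memory rows share one XOR pattern purely from the per-row bit count. A compile-time sketch of that selection for a few assumed element widths and K sizes:

// Mirrors the thresholds in Rows_per_xor_pattern_turing_a; inputs are assumed
// example values, not taken from a specific Traits/Cta_tile.
constexpr int rows_per_xor_pattern_turing_a(int bits_per_element_a, int n) {
  int const n_in_bits = n * bits_per_element_a;
  return n_in_bits <= 128 ? 1 : (n_in_bits <= 256 ? 2 : (n_in_bits <= 512 ? 4 : 8));
}

static_assert(rows_per_xor_pattern_turing_a(16, 64) == 8, "fp16 elements, K = 64");
static_assert(rows_per_xor_pattern_turing_a(8, 64) == 4, "int8 elements, K = 64");
static_assert(rows_per_xor_pattern_turing_a(16, 16) == 2, "fp16 elements, K = 16");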
+ this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop.- + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { + int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + uint2 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + } + + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_turing_row_a { + // The traits class. + using Traits = Turing_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_turing_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_turing_row_a { + // The traits class. + using Traits = Turing_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_turing_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_turing_row_a { + // The traits class. + using Traits = Turing_imma_int8_int32_traits; + // The base class. 
+ using Base = Smem_tile_turing_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_a { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_A }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_row_a : public Rows_per_xor_pattern_ampere_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_row_a::VALUE> +struct Smem_tile_ampere_row_a + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = + Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_a; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_ampere_row_a(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_M = Warp_masks::M; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 2, + ""); + + if (Base::ROWS_PER_XOR_PATTERN == 8) { + smem_read_row = (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 1 + (tidx & 0x0f); + smem_read_col = (tidx & 0x07); + smem_read_col ^= (tidx & 0x10) / 16; + } else if (Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 2 + (tidx & 0x0e) / 2; + smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4; + smem_read_col ^= (tidx & 0x10) / 16; + } else if (Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = + (tidx & WARP_MASK_M) / WARP_DIV_M * Mma_tile::M_PER_MMA / 4 + (tidx & 0x0c) / 4; + smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2; + smem_read_col ^= (tidx & 0x10) / 16; + } + + static_assert(WARPS_K <= 2, ""); + static_assert(WARPS_K != 2 || Base::ROWS_PER_XOR_PATTERN != 2, ""); + + // We "swap" the block for the second warp working on the same outputs in-CTA split-K. 
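In the Ampere row-A constructor above, the extra (tidx & 0x10) / 16 XOR term moves the second half-warp to a different 16B column even though it addresses the same rows as the first half-warp. A host-side check of that property for the ROWS_PER_XOR_PATTERN == 8 case (the warp offset along M is ignored, i.e. warp 0 is assumed):

// Verifies that threads tidx and tidx + 16 read the same row but different columns.
#include <cassert>

int main() {
  for (int tidx = 0; tidx < 16; ++tidx) {
    int const row_lo = tidx & 0x0f;
    int const col_lo = (tidx & 0x07) ^ ((tidx & 0x10) / 16);
    int const row_hi = (tidx + 16) & 0x0f;
    int const col_hi = ((tidx + 16) & 0x07) ^ (((tidx + 16) & 0x10) / 16);
    assert(row_lo == row_hi && col_lo != col_hi);
  }
  return 0;
}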
+ if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K * 2; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { + if (ki < Mma_tile::VALID_MMAS_K) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). + int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + + // Store the value into the fragment. + a[mi].reg(0) = tmp.x; + a[mi].reg(1) = tmp.y; + a[mi].reg(2) = tmp.z; + a[mi].reg(3) = tmp.w; + } + } + + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ampere_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ampere_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. 
+ inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ampere_hmma_bf16_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ampere_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_a + : public Smem_tile_ampere_row_a { + // The traits class. + using Traits = Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Smem_tile_ampere_row_a; + + // Ctor. + inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The layout of the tile. + typename Layout, + // The size of the STS. + int BYTES_PER_STS = 16, + // The number of buffers per tile. + int BUFFERS_PER_TILE = 1, + // Use or not predicates + bool USE_PREDICATES = true> +struct Smem_tile_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_volta_b { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 1 : (N_IN_BITS <= 512 ? 2 : 4) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. 
+ int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_volta_b::VALUE> +struct Smem_tile_volta_col_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = Smem_tile_without_skews; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + // The fragment. + using Fragment = Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_volta_col_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/xmma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_N = Warp_masks::N; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + if (Base::N_WITH_PADDING >= 64) { + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + + (tidx & 0x18) / 2 + (tidx & 0x03); + smem_read_col = (tidx & 0x03); + } else if (Base::N_WITH_PADDING == 32) { + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + + (tidx & 0x18) / 4 + (tidx & 0x02) / 2; + smem_read_col = (tidx & 0x02) / 2 + (tidx & 0x01) * 4; + } else { + assert(false); + } + + // For WARPS_K > 1, we do not support Base::N_WITH_PADDING < 64 for the moment. + static_assert(WARPS_K <= 2 && (WARPS_K == 1 || Base::N_WITH_PADDING >= 64), ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop.- + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Move the offset to the next position. See doc/xmma_smem_layout.xlsx. + this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Jump over as many rows as needed. + int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // TODO: Can we fuse read_offset and read_buffer? + uint4 tmp; + lds(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + + // Move the offset to the next position. See doc/xmma_smem_layout.xlsx. 
+ static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_volta_col_b { + // The traits class. + using Traits = fmha::Volta_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_volta_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_turing_b { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 128 ? 1 : (N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8)) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_turing_b::VALUE> +struct Smem_tile_turing_col_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = + Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_turing_col_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_N = Warp_masks::N; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. 
+ int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 1, + ""); + + if (Base::ROWS_PER_XOR_PATTERN == 8) { + // For group fprop. B is divided into 2 halves along N dimension. + // The fist warp takes the first half and the second warp takes the second half. + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x0f); + smem_read_col = (tidx & 0x07); + } else if (Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = + (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + (tidx & 0x0e) / 2; + smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4; + } else if (Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = + (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 4 + (tidx & 0x0c) / 4; + smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2; + } else if (Base::ROWS_PER_XOR_PATTERN == 1) { + smem_read_row = + (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 8 + (tidx & 0x1f) / 8; + smem_read_col = (tidx & 0x07); + } + + static_assert(WARPS_K <= 2, ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop.- + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + this->smem_read_offset_ ^= ((ki % 2 == 0) ? 1 : 3) * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + uint2 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + } + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. + static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. 
+ this->smem_read_offset_ ^= MASK * BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_turing_col_b { + // The traits class. + using Traits = Turing_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_turing_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_turing_col_b { + // The traits class. + using Traits = Turing_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_turing_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_turing_col_b { + // The traits class. + using Traits = Turing_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_turing_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_b { + // The size in bits. + enum { N_IN_BITS = N * Traits::BITS_PER_ELEMENT_B }; + + // The number of rows. + enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_col_b : public Rows_per_xor_pattern_ampere_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_col_b::VALUE> +struct Smem_tile_ampere_col_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = + Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // When we use padding to reach a power of two, special care has to be taken. + using Cta_tile_with_padding = Cta_tile_with_k_with_padding; + // The number of MMAs. + using Mma_tile_with_padding = typename Traits::template Mma_tile; + + // The size of a single LDS in bytes. 
+ enum { BYTES_PER_LDS = 16 }; + + // The number of STS per thread + enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; + + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_ampere_col_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/mma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_N = Warp_masks::N; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row and column read by the thread. + int smem_read_row, smem_read_col; + + static_assert(Base::ROWS_PER_XOR_PATTERN == 8 || Base::ROWS_PER_XOR_PATTERN == 4 || + Base::ROWS_PER_XOR_PATTERN == 2, + ""); + + if (Base::ROWS_PER_XOR_PATTERN == 8) { + // For group fprop. B is divided into 2 halves along N dimension. + // The fist warp takes the first half and the second warp takes the second half. + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 1 + (tidx & 0x07) + + (tidx & 0x10) / 2; + smem_read_col = (tidx & 0x07); + smem_read_col ^= (tidx & 0x08) / 8; + } else if (Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 2 + + (tidx & 0x06) / 2 + (tidx & 0x10) / 4; + smem_read_col = (tidx & 0x06) / 2 + (tidx & 0x01) * 4; + smem_read_col ^= (tidx & 0x08) / 8; + } else if (Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA / 4 + + (tidx & 0x04) / 4 + (tidx & 0x10) / 8; + smem_read_col = (tidx & 0x04) / 4 + (tidx & 0x03) * 2; + smem_read_col ^= (tidx & 0x08) / 8; + } + + static_assert(WARPS_K <= 2, ""); + static_assert(WARPS_K != 2 || Base::ROWS_PER_XOR_PATTERN != 2, ""); + + // We "swap" the block for the second warp working on the in-CTA split-K. + if (WARPS_K == 2) { + smem_read_col ^= (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile_with_padding::MMAS_K * 2; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { + if (ki < Mma_tile::VALID_MMAS_K) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). + int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; + + // Load using LDSM.M88.4. + uint4 tmp; + ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); + + // Store the value into the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + } + + // Move the offset to the next position. See doc/mma_smem_layout.xlsx. 
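The STS_PER_THREAD_ expression above simply divides the number of store slots in the tile by the CTA size, clamped to at least one store per thread. A quick arithmetic sketch with assumed tile shapes:

// Same formula as STS_PER_THREAD_ plus the Max<1, .> clamp; inputs are assumed.
constexpr int sts_per_thread(int rows, int threads_per_row, int threads_per_cta) {
  int const raw = rows * threads_per_row / threads_per_cta;
  return raw < 1 ? 1 : raw;
}

static_assert(sts_per_thread(64, 8, 128) == 4, "4 stores per thread");
static_assert(sts_per_thread(16, 4, 128) == 1, "clamped by Max<1, .>");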
+ static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); + if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile_with_padding::MMAS_K >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + + // Reset the read offset. + inline __device__ void reset_read_offset() { + // The number of MMAs in the K dimension. + enum { MMAS_K = Mma_tile::MMAS_K }; + + // The number of MMAs in the K dimension when we include padding. + enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; + + // Assemble the mask. + enum { MASK = Compute_reset_mask::VALUE }; + + // Reset the read offset. + this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ampere_hmma_fp16_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ampere_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ampere_hmma_bf16_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ampere_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. 
+ inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_col_b { + // The traits class. + using Traits = Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Smem_tile_ampere_col_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rows_per_xor_pattern_ampere_row_b : public Rows_per_xor_pattern_ampere_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The instruction traits. + typename Traits, + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE, + // How many rows to use for the XOR pattern to avoid bank conflicts? + int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_ampere_row_b::VALUE, + // How many cols to use for the XOR pattern to avoid bank conflicts? + int COLS_PER_XOR_PATTERN_ = 1> +struct Smem_tile_ampere_row_b + : public Smem_tile_without_skews { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The base class. + using Base = Smem_tile_without_skews; + // The fragment. + using Fragment = Fragment_b; + + // Can we use LDSM? No if the data type is 32-bit large. + enum { USE_LDSMT = Traits::BITS_PER_ELEMENT_B == 16 }; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 }; + + // The number of elements per LDS. + enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / Traits::BITS_PER_ELEMENT_B }; + + // The number of STS per thread + enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; + + // The number of STS per thread must be at least 1. + enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; + + // Ctor. + inline __device__ Smem_tile_ampere_row_b(void* smem, int tidx) : Base(smem, tidx) { + // For documentation on the layout, see doc/xmma_smem_layout.xlsx. + + // The number of warps. + int const WARPS_M = Cta_tile::WARPS_M; + int const WARPS_N = Cta_tile::WARPS_N; + int const WARPS_K = Cta_tile::WARPS_K; + + // The masks to select the warps. + int const WARP_MASK_N = Warp_masks::N; + int const WARP_MASK_K = Warp_masks::K; + + // The divisor for the warps. + int const WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + int const WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // The row/col read by the thread. 
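The row-B tile declared above switches between ldsmt and plain lds based solely on the element width, which in turn fixes the bytes and elements moved per load. A compile-time sketch of those constants for assumed element widths:

// Mirrors USE_LDSMT, BYTES_PER_LDS and ELEMENTS_PER_LDS from the class above.
constexpr bool use_ldsmt(int bits_per_element_b) { return bits_per_element_b == 16; }
constexpr int bytes_per_lds(int bits_per_element_b) {
  return use_ldsmt(bits_per_element_b) ? 16 : 4;
}
constexpr int elements_per_lds(int bits_per_element_b) {
  return bytes_per_lds(bits_per_element_b) * 8 / bits_per_element_b;
}

static_assert(elements_per_lds(16) == 8, "16-bit B: one ldsmt moves 8 elements");
static_assert(elements_per_lds(32) == 1, "32-bit B: one 4B lds moves 1 element");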
+ int smem_read_row, smem_read_col; + + static_assert((USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 8) || + Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 2, + ""); + + if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 8) { + // For group dgrad. B is divided into 2 halves along K dimension. + // The fist warp takes the first half and the second warp takes the second half. + smem_read_row = + (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + (tidx & 0x07) + (tidx & 0x08); + smem_read_col = (tidx & 0x07); + } else if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 4) { + smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 8 + (tidx & 0x06) / 2 + + (tidx & 0x08) / 2; + smem_read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + } else if (USE_LDSMT && Base::ROWS_PER_XOR_PATTERN == 2) { + smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 4 + (tidx & 0x04) / 4 + + (tidx & 0x08) / 4; + smem_read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + } else if (Base::ROWS_PER_XOR_PATTERN == 4 && Base::COLS_PER_XOR_PATTERN == 2) { + smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 8 + (tidx & 0x03); + smem_read_col = (tidx & 0x1c) / 4 + (tidx & 0x03) * 8; + } + + // Each half-warp applies a different XOR pattern -- see the Excel document. + if (USE_LDSMT) { + smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16; + } else { + smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 16; + } + + // The shared memory offset. + this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS; + + // Fill zeroes for group conv + } + + // Rewind smem_read_offset for last LDS phase in main loop. + inline __device__ void reverse_smem_read_offset(int ki = 0) { + // The size of each element in bits. + int const BITS_PER_ELT = Traits::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + int const BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Undo the pointer increment for the next ni. + // Should match the load function below for ki = 0. + if (BYTES_PER_MMA_PER_CTA >= 128) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } + } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { + // The size of each element in bits. + int const BITS_PER_ELT = Traits::BITS_PER_ELEMENT_B; + // The size in bytes of the data needed to compute an MMA per CTA. + int const BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // Prepare the offset. 
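The ni-stepping branches in reverse_smem_read_offset() above (and in load() below) are keyed off BYTES_PER_MMA_PER_CTA, the per-CTA byte footprint of one MMA along N. A small sketch of that quantity for assumed tile widths, with the branch each value would select noted in the assertion messages:

// BYTES_PER_MMA_PER_CTA = N_PER_MMA_PER_CTA * bits_per_element_b / 8.
// The inputs below are assumed examples, not taken from a specific Mma_tile.
constexpr int bytes_per_mma_per_cta(int n_per_mma_per_cta, int bits_per_element_b) {
  return n_per_mma_per_cta * bits_per_element_b / 8;
}

static_assert(bytes_per_mma_per_cta(64, 16) == 128, ">= 128B: no column hopping between ni");
static_assert(bytes_per_mma_per_cta(32, 16) == 64, "64B: XOR by the MMA footprint when MMAS_N > 1");
static_assert(bytes_per_mma_per_cta(16, 16) == 32, "32B: the LDS-sized XOR ladder");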
+ int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW; + if (BYTES_PER_MMA_PER_CTA == 32) { + offset += this->smem_read_offset_; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2; + } else { + offset += this->smem_read_offset_ + (ni)*BYTES_PER_MMA_PER_CTA; + } + + // Load the data using LDSM.MT88.2. + uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset; + + if (ni < Mma_tile::VALID_MMAS_N) { + uint4 tmp; + if (USE_LDSMT) { + ldsmt(tmp, ptr); + } else { + lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW); + lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW); + lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW); + lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW); + } + + // Store those values in the fragment. + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + + // static_assert(BYTES_PER_MMA_PER_CTA >= 128 || + // BYTES_PER_MMA_PER_CTA == 64 || + // (BYTES_PER_MMA_PER_CTA == 32 && + // (Mma_tile::MMAS_M == 4 || + // Mma_tile::MMAS_M == 2 || + // Mma_tile::MMAS_M == 1)), ""); + + // Move the pointer for the next ni. I expect the compiler to not recompute those. + if (BYTES_PER_MMA_PER_CTA >= 128) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } else if (BYTES_PER_MMA_PER_CTA == 64) { + // Nothing to do! + } else if (BYTES_PER_MMA_PER_CTA == 32) { + if ((ni & 1) == 0) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 16 && (ni & 7) == 7) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 30; + } else if (Mma_tile::MMAS_N >= 8 && (ni & 3) == 3) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 14; + } else if (Mma_tile::MMAS_N >= 4 && (ni & 1) == 1) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 6; + } + } + } + + // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) + if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) { + this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_row_b { + // The traits class. + using Traits = Ampere_hmma_fp32_traits; + // The base class. + using Base = Smem_tile_ampere_row_b; + + // Ctor. + inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The dimensions of the tile computed by the CTA. + typename Cta_tile, + // The size of the STS. + int BYTES_PER_STS, + // The number of buffers per tile. + int BUFFERS_PER_TILE> +struct Smem_tile_b + : public Smem_tile_ampere_row_b { + // The traits class. + using Traits = Ampere_hmma_bf16_traits; + // The base class. + using Base = Smem_tile_ampere_row_b; + + // Ctor. 
+ inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/smem_tile_o.h b/csrc/fmha_v2/fmha/smem_tile_o.h new file mode 100644 index 0000000000..af7311a111 --- /dev/null +++ b/csrc/fmha_v2/fmha/smem_tile_o.h @@ -0,0 +1,1646 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// H M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o { + // The instruction traits. + using Traits = Volta_hmma_fp16_16x16x16_traits; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The accumulators. + using Data_type = typename Accumulator::Data_type; + + // The size of each element. + enum { BYTES_PER_ELEMENT = sizeof(Data_type) }; + + // The size of each STS. + enum { BYTES_PER_STS = 16 }; + + // The size of each row in shared memory. + enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * 2 * BYTES_PER_ELEMENT }; + + // The size of each LDS. + enum { BYTES_PER_LDS = 16 }; + + // The number of threads (to produce 16B per LDS). + enum { THREADS_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loops. + enum { LOOPS = fmha::Div_up::VALUE }; + + // Make sure it matches our expectations. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; + + // The amount of shared memory. + enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW }; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. 
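The sizing enums of this Volta output tile are easiest to read with concrete numbers plugged in. The sketch below assumes M = 128, N = 64, 128 threads per CTA, WARPS_K = 1, 2-byte output elements, and M_PER_MMA_PER_CTA = 64; all of these are illustrative assumptions, not values taken from a specific Cta_tile.

// Worked example of the Smem_tile_o sizing constants for one assumed config.
constexpr int div_up(int a, int b) { return (a + b - 1) / b; }

constexpr int kBytesPerRow = 64 /*N*/ * 1 /*WARPS_K*/ * 2 * 2 /*bytes per element*/;  // 256
constexpr int kThreadsPerRow = 64 * 2 / 16 /*BYTES_PER_LDS*/;                          // 8
constexpr int kRowsPerLds = 128 /*threads per CTA*/ / kThreadsPerRow;                  // 16
constexpr int kRowsPerLoop = 64;  // max(kRowsPerLds, min(M, M_PER_MMA_PER_CTA))
constexpr int kLoops = div_up(128 /*M*/, kRowsPerLoop);                                // 2
constexpr int kLdsPerLoop = div_up(kRowsPerLoop, kRowsPerLds);                         // 4

static_assert(kLoops == 2 && kLdsPerLoop == 4, "two outer loops, four LDS per loop");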
+ enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The row/col written by the thread. + int write_row, write_col; + + // SEQLEN == 128. Segments of 128B are written by 2 warps. + if (WARPS_2x1x2 && Cta_tile::N == 32) { + write_row = (tidx & 0x30) / 2 + (tidx & 0x07); + write_col = (tidx & 0x0f); + write_col ^= (tidx & 0x40) / 16; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x30) / 2 + (tidx & 0x07); + write_col = (tidx & 0x40) / 8 + (tidx & 0x08) * 2 + (tidx & 0x07); + + // SEQLEN == 256, 384 and N == 32. Segments of 128B are written by 2 warps. + } else if (WARPS_1x1x4 && Cta_tile::N == 32) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x40) / 8 + (tidx & 0x08) * 2 + (tidx & 0x07); + write_col ^= (tidx & 0x20) / 8; + + // SEQLEN == 256, 384 and N == 64. + } else if (WARPS_1x1x4 && Cta_tile::N == 64) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 4 + (tidx & 0x08) * 4 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and N == 128. + } else if (WARPS_1x1x4 && Cta_tile::N == 128) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 2 + (tidx & 0x08) * 8 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and N == 256. + } else if (WARPS_1x1x4 && Cta_tile::N == 256) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0x60) / 1 + (tidx & 0x08) * 16 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and N == 32. Segments of 128B are written by 2 warps. + } else if (WARPS_1x1x8 && Cta_tile::N == 32) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0xc0) / 8 + (tidx & 0x08) * 4 + (tidx & 0x07); + write_col ^= (tidx & 0x20) / 8; + + // SEQLEN == 256, 384, 512 and N == 64. + } else if (WARPS_1x1x8 && Cta_tile::N == 64) { + write_row = (tidx & 0x10) / 2 + (tidx & 0x07); + write_col = (tidx & 0xe0) / 4 + (tidx & 0x08) * 8 + (tidx & 0x07); + + // ANY SEQLEN and N == 32 + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + write_row = (tidx & 0xf0) / 2 + (tidx & 0x07); + write_col = (tidx & 0x07); + write_col ^= (tidx & 0x08) / 2; + + // ANY SEQLEN and N == 64 + } else if (WARPS_4x1x1 && Cta_tile::N == 64) { + write_row = (tidx & 0x70) / 2 + (tidx & 0x07); + write_col = (tidx & 0x0f); + + // ANY SEQLEN and N == 128 + } else if (WARPS_4x1x1 && Cta_tile::N == 128) { + write_row = (tidx & 0x70) / 2 + (tidx & 0x07); + write_col = (tidx & 0x08) + (tidx & 0x0f); + + // ANY SEQLEN and N == 256 + } else if (WARPS_4x1x1 && Cta_tile::N == 256) { + write_row = (tidx & 0x70) / 2 + (tidx & 0x07); + write_col = (tidx & 0x08) * 3 + (tidx & 0x0f); + + // Not supported. + } else { + assert(false); + } + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= read_row & 0x7; + + // Assemble the read pointer. 
+ smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; + } + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { + uint32_t local_smem_read_ = smem_read_; +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Apply the XOR pattern if needed. (XOR 8 default) + if (ROWS_PER_LDS < 8) { + local_smem_read_ = (smem_read_ ^ ((ii * ROWS_PER_LDS) % 8 * BYTES_PER_LDS)); + } + + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K * 2]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K * 2; ++jj) { + // The immediate. + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW; + if (Cta_tile::N == 256) { + imm += jj * 512; + } else if (Cta_tile::N == 128) { + imm += jj * 256; + } else if (Cta_tile::N == 64) { + imm += jj * 128; + } else if (Cta_tile::N == 32) { + imm += jj / 2 * 128; + } else { + assert(false); + } + + // The XOR mask. + int smem_read_offset = local_smem_read_; + if (Cta_tile::N == 32 && (jj % 2) == 1) { + smem_read_offset ^= 64; + } + + // Load... + if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || is_active_for_last_lds_)) { + fmha::lds(tmp[jj], smem_read_offset + imm); + } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K * 2; ++jj) { + out[ii] = fmha::hadd8(out[ii], tmp[jj]); + } + } + } + + // Store the accumulators. + template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::VALID_MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + +// Store 1st column of the different MMAs. +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Assemble the vectors for the stores. See how we swizzle the registers. + uint4 tmp_0; + tmp_0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp_0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + tmp_0.z = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp_0.w = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + uint4 tmp_1; + tmp_1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp_1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + tmp_1.z = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp_1.w = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + + // Precompute the immediates to jump to the correct row. + int row = mj * M_PER_MMA * BYTES_PER_ROW; + + // The columns. + int smem_write_0 = smem_write_ ^ ((2 * ni + 0) * BYTES_PER_STS); + int smem_write_1 = smem_write_ ^ ((2 * ni + 1) * BYTES_PER_STS); + + // Store. + fmha::sts(smem_write_0 + row, tmp_0); + fmha::sts(smem_write_1 + row, tmp_1); + } + } + } + + // The write pointer. + uint32_t smem_write_; + // The write pointer. + uint32_t smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// This class converts the FP16/FP32 inputs to FP16x2. + +struct Convert_from_fp16 { + // Convert one pair of fp16 numbers. + template + static inline __device__ uint32_t convert(Accumulators const& acc, int ii) { + // Extract the 2x FP16 numbers (packed in a register). 
+ uint32_t h2 = acc.reg(ii); + + return h2; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Convert_from_fp32 { + // Convert one pair of fp16 numbers. + template + inline __device__ uint32_t convert(Accumulators const& acc, int ii) { + // Extract the 2x floats. + float f0 = acc.elt(ii * 2 + 0); + float f1 = acc.elt(ii * 2 + 1); + + // Convert to FP16x2. + return fmha::float2_to_half2(f0, f1); + } + + // The bf16 accumulators (convert from fp32 to 2xbf16). + using Ampere_bf16_Accumulator = fmha::Fragment_accumulator; + + static inline __device__ uint32_t convert(Ampere_bf16_Accumulator const& acc, int ii) { + // Extract the 2x floats. + float f0 = acc.elt(ii * 2 + 0); + float f1 = acc.elt(ii * 2 + 1); + + // Convert to FP16x2. + return fmha::float2_to_bf16_x2(f0, f1); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_smem_tile_o { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The data type. + using Data_type = typename Accumulator::Data_type; + // The epilogue data type + using Epilogue_type = typename Traits::Epilogue_type; + + // The size of each element. + enum { BYTES_PER_ELEMENT = sizeof(Epilogue_type) }; + + // The amount of bytes per row (without packing or split-k). + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The size of each STS. + enum { BYTES_PER_STS = BYTES_PER_STS_ }; + + // The size of each LDS. + enum { BYTES_PER_LDS = 16 }; + + // The number of threads (to produce 16B per LDS). + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The number of rows in shared memory. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loops. + enum { LOOPS = fmha::Div_up::VALUE }; + + // Make sure it matches our expectations. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; + + // The amount of shared memory. + enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW * Cta_tile::WARPS_K }; + + // The amount of row packing to make sure we have at least 128B per smem row (without split-k). + enum { ROW_PACKING = Max<1, 128 / BYTES_PER_ROW>::VALUE }; + + // Make sure our row packing is correct + static_assert(ROWS_PER_LOOP % ROW_PACKING == 0, ""); + + // The amount of shared memory per row after packing. + enum { BYTES_PER_ROW_WITH_PACKING = BYTES_PER_ROW * ROW_PACKING }; + + // Make sure we have at least 128B per row after packing. + static_assert(BYTES_PER_ROW_WITH_PACKING >= 128, ""); + + // The number of threads per row after packing. + enum { THREADS_PER_ROW_WITH_PACKING = THREADS_PER_ROW * ROW_PACKING }; + + // Make sure we have at least 8 threads per row after packing. + static_assert(THREADS_PER_ROW_WITH_PACKING >= 8, ""); + + // Warps. 
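+  // Note: WARPS_K > 1 means the CTA splits the GEMM-K dimension across warps. Each warp then
+  // writes its partial O tile into its own slice of the packed row (BYTES_PER_TILE scales with
+  // WARPS_K), and load() reduces the WARPS_K slices back into a single result.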
+ enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + enum { WARPS_4x1x2 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2 }; + + // Ctor. + inline __device__ Hmma_smem_tile_o(void* smem, int tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The row/col written by the thread. + int write_row, write_col; + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + write_row = (tidx & 0x20) / 8 + (tidx & 0x10) / 16; + write_col = (tidx & 0x40) / 2 + (tidx & 0x0c) * 2 + (tidx & 0x03); + write_col ^= (tidx & 0x10) / 4; + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + write_row = (tidx & 0x20) / 4 + (tidx & 0x18) / 8; + write_col = (tidx & 0x40) / 2 + (tidx & 0x04) * 4 + (tidx & 0x03); + write_col ^= (tidx & 0x18) / 2; + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x20) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x40) / 2 + (tidx & 0x03); + write_col ^= (tidx & 0x1c); + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 128. + } else if (WARPS_2x1x2 && Cta_tile::N == 128) { + write_row = (tidx & 0x20) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x40) / 1 + (tidx & 0x1f); + + // SEQLEN == 256, 384, 512 and HIDDEN_SIZE_PER_HEAD == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + write_row = (tidx & 0x10) / 16; + write_col = (tidx & 0x0c) * 2 + (tidx & 0xe3); + write_col ^= (tidx & 0x10) / 4; + + // SEQLEN == 256, 384, 512 and HIDDEN_SIZE_PER_HEAD == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + write_row = (tidx & 0x18) / 8; + write_col = (tidx & 0x04) * 4 + (tidx & 0xe3); + write_col ^= (tidx & 0x18) / 2; + + // SEQLEN == 256, 384 and HIDDEN_SIZE_PER_HEAD == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xff); + + // SEQLEN == 256, 384 and HIDDEN_SIZE_PER_HEAD == 128. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 128) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) * 2 + (tidx & 0x1f); + + // SEQLEN == 256, 384 and HIDDEN_SIZE_PER_HEAD == 256. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 256) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) * 4 + (tidx & 0x1f); + + // ANY SEQLEN and HIDDEN_SIZE_PER_HEAD == 16. + } else if (WARPS_4x1x1 && Cta_tile::N == 16) { + write_row = (tidx & 0xe0) / 8 + (tidx & 0x10) / 16; + write_col = (tidx & 0x0c) * 2 + (tidx & 0x03); + write_col ^= (tidx & 0x10) / 4; + + // ANY SEQLEN and HIDDEN_SIZE_PER_HEAD == 32. + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + write_row = (tidx & 0xe0) / 4 + (tidx & 0x18) / 8; + write_col = (tidx & 0x04) * 4 + (tidx & 0x03); + write_col ^= (tidx & 0x18) / 2; + + // ANY SEQLEN and HIDDEN_SIZE_PER_HEAD == 64/128. 
+ } else if (WARPS_4x1x1 && (Cta_tile::N == 64 || Cta_tile::N == 128)) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x1f); + + // ANY SEQLEN and HIDDEN_SIZE_PER_HEAD == 256. + } else if (WARPS_4x1x1 && (Cta_tile::N == 256 || Cta_tile::N == 512)) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x1f); + + // GMMA: S=284/512 and HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_4x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x80) / 4 + (tidx & 0x03); + write_col ^= (tidx & 0x1c); + + // GMMA: S=284/512 and HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_4x1x2 && Cta_tile::N == 32) { + write_row = (tidx & 0x60) / 4 + (tidx & 0x1c) / 8; + write_col = (tidx & 0x80) / 4 + (tidx & 0x04) * 4 + (tidx & 0x03); + write_col ^= (tidx & 0x18) / 2; + + // Not supported. + } else { + assert(false); + } + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW_WITH_PACKING * Cta_tile::WARPS_K + + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < ROWS_PER_LOOP; + } + + // The XOR params. + int const XOR_MOD = 8 / ROW_PACKING; + + // Take the XOR pattern and the packing into account for the column. + read_col += read_row % ROW_PACKING * XOR_MOD; + read_row /= ROW_PACKING; + read_col ^= read_row % XOR_MOD; + + // Assemble the read pointer. + smem_read_ = smem_ + read_row * BYTES_PER_ROW_WITH_PACKING * Cta_tile::WARPS_K + + read_col * BYTES_PER_LDS; + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { + uint32_t local_smem_read_ = smem_read_; +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Apply the XOR pattern if needed. (XOR 8 default) + if (ROWS_PER_LDS < 8) { + local_smem_read_ = (smem_read_ ^ ((ii * ROWS_PER_LDS) % 8 * BYTES_PER_LDS)); + } + + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + // Note: ROWS_PER_LDS does not take packing into account - hence BYTES_PER_ROW. + int imm = + ii * ROWS_PER_LDS * BYTES_PER_ROW * Cta_tile::WARPS_K + jj * BYTES_PER_ROW_WITH_PACKING; + + // Load... + if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || is_active_for_last_lds_)) { + fmha::lds(tmp[jj], local_smem_read_ + imm); + } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + out[ii] = fmha::add8(out[ii], tmp[jj]); + } + } + } + + // Store the accumulators. + template + inline __device__ void store_(Accumulators const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + + Converter converter; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + // Store 1st column of the different MMAs. + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. 
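+          // Each thread of an HMMA quad owns two output rows 8 apart, hence the +0/+8 pair
+          // below. The row stride is scaled by WARPS_K because each logical row stores its
+          // WARPS_K split-K slices side by side in shared memory.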
+ int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + // The values (2 halves per register). + uint32_t h0 = converter.convert(acc[mi * MMAS_M_PER_LOOP + mj][ni], 0); + uint32_t h1 = converter.convert(acc[mi * MMAS_M_PER_LOOP + mj][ni], 1); + + // Store to shared memory. + fmha::sts(smem_write_ + row_0, h0); + fmha::sts(smem_write_ + row_1, h1); + } + } + + // Swizzle the write pointer using a XOR of 16B. + smem_write_ ^= 16; + + // Store 2nd column of the different MMAs. + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K; + + // The values (2 halves per register). + uint32_t h2 = converter.convert(acc[mi * MMAS_M_PER_LOOP + mj][ni], 2); + uint32_t h3 = converter.convert(acc[mi * MMAS_M_PER_LOOP + mj][ni], 3); + + // Store to shared memory. + fmha::sts(smem_write_ + row_0, h2); + fmha::sts(smem_write_ + row_1, h3); + } + } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. + if (ROW_PACKING == 4) { + smem_write_ ^= 16; + } else if (ROW_PACKING == 2) { + smem_write_ ^= 3 * 16; + } else if (ROW_PACKING == 1) { + // ย  ย  ย  ย  7 + // ย  ย  ย  / ย  ย \ + // ย  ย  ย 3 ย  ย  ย 3 + // ย  ย / ย \ ย  ย / ย \ + // ย  1 ย  ย 1 ย 1 ย  ย 1 + static_assert(Mma_tile::MMAS_N <= 64, ""); + if (Mma_tile::MMAS_N >= 32 && ni % 16 == 15) { + smem_write_ ^= 63 * 16; + } else if (Mma_tile::MMAS_N >= 16 && ni % 8 == 7) { + smem_write_ ^= 31 * 16; + } else if (Mma_tile::MMAS_N >= 8 && ni % 4 == 3) { + smem_write_ ^= 15 * 16; + } else if (Mma_tile::MMAS_N >= 4 && ni % 2 == 1) { + smem_write_ ^= 7 * 16; + } else if (Mma_tile::MMAS_N >= 2) { + smem_write_ ^= 3 * 16; + } + } else { + assert(false); + } + } + } + + // The write pointer. + uint32_t smem_write_; + // The write pointer. + uint32_t smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Turing_hmma_fp16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + // The FP16 accumulators. + using Accumulators_fp16 = fmha::Fragment_accumulator; + // The FP32 accumulators. + using Accumulators_fp32 = fmha::Fragment_accumulator; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store from FP16 accumulators. That's the default. + template + inline __device__ void store(Accumulators_fp16 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } + + // Store from FP32 accumulators. Special trick for the Flash-attention kernel. + // Convert from fp32 to fp16 before STS + template + inline __device__ void store(Accumulators_fp32 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Ampere_hmma_fp16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + // The FP16 accumulators. 
+ using Accumulators_fp16 = fmha::Fragment_accumulator; + // The FP32 accumulators. + using Accumulators_fp32 = fmha::Fragment_accumulator; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store from FP16 accumulators. That's the default. + template + inline __device__ void store(Accumulators_fp16 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } + + // Store from FP32 accumulators. Special trick for the Flash-attention kernel. + // Convert from fp32 to fp16 before STS + template + inline __device__ void store(Accumulators_fp32 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Ampere_hmma_bf16_bf16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + + // The FP32 accumulators (only FP32 acc is supported for BF16 MMA). + using Accumulators_bf16 = fmha::Fragment_accumulator; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} + + // Store from FP32 accumulators. Special trick for the Flash-attention kernel. + // Convert from fp32 to bf16 before STS + template + inline __device__ void store(Accumulators_bf16 const (&acc)[M][N], int mi) { + this->template store_(acc, mi); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Hmma_smem_tile_o; + // The MMA tile. + using Mma_tile = typename Base::Mma_tile; + // The accumulators. + using Accumulator = typename Base::Accumulator; + + // The size of each + enum { BYTES_PER_ELEMENT = Base::BYTES_PER_ELEMENT }; + + // The size of each row in shared memory. + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW * Cta_tile::WARPS_K }; + + // The size of each row in shared memory. + enum { BYTES_PER_LDS = Base::BYTES_PER_LDS }; + + // The number of threads (to produce 16B per LDS). + enum { THREADS_PER_ROW = Base::THREADS_PER_ROW }; + + // The number of outer loops. + enum { LOOPS = Base::LOOPS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Base::ROWS_PER_LDS }; + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = Base::HAS_INCOMPLETE_LDS }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = Base::LDS_PER_LOOP }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= (read_row & 0x7) * 2; + + // Assemble the read pointer. + this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; + } + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). 
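+      // tmp[jj] receives the partial tile written by the jj-th split-K warp; the fadd4 loop
+      // below then reduces the WARPS_K partials into a single fp32 vector.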
+ uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; + int is_valid = ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_; + if (!HAS_INCOMPLETE_LDS || is_valid) { + fmha::lds(tmp[jj], this->smem_read_ + imm); + } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + out[ii] = fmha::fadd4(out[ii], tmp[jj]); + } + } + } + + // Store the accumulators. + template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + // Store 1st column of the different MMAs. + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + // Pack vectors. + uint2 tmp0; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + + uint2 tmp1; + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + } + + // Swizzle the write pointer using a XOR of 16B. + this->smem_write_ ^= 32; + + // Store 2nd column of the different MMAs. + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. + static_assert(Mma_tile::MMAS_N <= 16, ""); + if (Mma_tile::MMAS_N >= 16 && (ni & 7) == 7) { + this->smem_write_ ^= 31 * 32; + } else if (Mma_tile::MMAS_N >= 8 && (ni & 3) == 3) { + this->smem_write_ ^= 15 * 32; + } else if (Mma_tile::MMAS_N >= 4 && (ni & 1) == 1) { + this->smem_write_ ^= 7 * 32; + } else if ((ni & 1) == 0) { + this->smem_write_ ^= 3 * 32; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Hmma_smem_tile_o { + // The traits class. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Hmma_smem_tile_o; + // The MMA tile. + using Mma_tile = typename Base::Mma_tile; + // The accumulators. + using Accumulator = typename Base::Accumulator; + + // The size of each element. + enum { BYTES_PER_ELEMENT = Base::BYTES_PER_ELEMENT }; + + // The size of each row in shared memory. + enum { BYTES_PER_ROW = Base::BYTES_PER_ROW * Cta_tile::WARPS_K }; + + // The size of each row in shared memory. + enum { BYTES_PER_LDS = Base::BYTES_PER_LDS }; + + // The number of threads (to produce 16B per LDS). 
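+  // For example, assuming 4-byte fp32 epilogue elements and Cta_tile::N = 64, the base row is
+  // 256B, so THREADS_PER_ROW = 256 / 16 = 16 and a 4-warp CTA covers 8 rows per LDS.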
+ enum { THREADS_PER_ROW = Base::THREADS_PER_ROW }; + + // The number of outer loops. + enum { LOOPS = Base::LOOPS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Base::ROWS_PER_LDS }; + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = Base::HAS_INCOMPLETE_LDS }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = Base::LDS_PER_LOOP }; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Take the XOR pattern into account for the column. + read_col ^= (read_row & 0x7) * 2; + + // Assemble the read pointer. + this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; + } + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; + int is_valid = ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_; + if (!HAS_INCOMPLETE_LDS || is_valid) { + fmha::lds(tmp[jj], this->smem_read_ + imm); + } + } + + // Perform the reduction. + out[ii] = tmp[0]; +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + out[ii] = fmha::fadd4(out[ii], tmp[jj]); + } + } + } + + // Store the accumulators. + template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + + // Store 1st column of the different MMAs. + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + // Pack vectors. + uint2 tmp0; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); + + uint2 tmp1; + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); + + // Store. + fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + } + + // Swizzle the write pointer using a XOR of 16B. + this->smem_write_ ^= 32; + + // Store 2nd column of the different MMAs. + if (ni < Mma_tile::VALID_MMAS_N) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + // Precompute the immediates to jump between rows. + int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; + int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; + + uint2 tmp0, tmp1; + tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); + tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); + + tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); + tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); + + // Store. 
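+          // Registers 4..7 hold the second 8-column group of the 16-wide MMA tile; the 32B
+          // XOR on the write pointer above moves these stores to that group's slot.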
+ fmha::sts(this->smem_write_ + row_0, tmp0); + fmha::sts(this->smem_write_ + row_1, tmp1); + } + } + + // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. + static_assert(Mma_tile::MMAS_N <= 16, ""); + if ((ni & 1) == 0) { + this->smem_write_ ^= 3 * 32; + } else if (Mma_tile::MMAS_N >= 16 && (ni & 7) == 7) { + this->smem_write_ ^= 31 * 32; + } else if (Mma_tile::MMAS_N >= 8 && (ni & 3) == 3) { + this->smem_write_ ^= 15 * 32; + } else if (Mma_tile::MMAS_N >= 4 && (ni & 1) == 1) { + this->smem_write_ ^= 7 * 32; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// I M M A +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// each thread holds 8 accumulator registers per 16x16 MMA, representing a 2x4 tile +template +struct Regs_to_rows { + template + static inline __device__ void extract(Acc const& acc, uint4& row0, uint4& row1) { + // Volta/Turing: row-major + uint32_t tmp_00 = acc.reg(0); + uint32_t tmp_01 = acc.reg(2); + uint32_t tmp_02 = acc.reg(1); + uint32_t tmp_03 = acc.reg(3); + uint32_t tmp_10 = acc.reg(4); + uint32_t tmp_11 = acc.reg(6); + uint32_t tmp_12 = acc.reg(5); + uint32_t tmp_13 = acc.reg(7); + + row0.x = tmp_00; + row0.y = tmp_01; + row0.z = tmp_02; + row0.w = tmp_03; + + row1.x = tmp_10; + row1.y = tmp_11; + row1.z = tmp_12; + row1.w = tmp_13; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Regs_to_rows_8bit { + template + static inline __device__ void extract(Acc const& acc, uint4& row0, uint4& row1) { + // Ampere: col-major + uint32_t tmp_00 = acc.reg(0); + uint32_t tmp_01 = acc.reg(4); + uint32_t tmp_02 = acc.reg(1); + uint32_t tmp_03 = acc.reg(5); + uint32_t tmp_10 = acc.reg(2); + uint32_t tmp_11 = acc.reg(6); + uint32_t tmp_12 = acc.reg(3); + uint32_t tmp_13 = acc.reg(7); + + row0.x = tmp_00; + row0.y = tmp_01; + row0.z = tmp_02; + row0.w = tmp_03; + + row1.x = tmp_10; + row1.y = tmp_11; + row1.z = tmp_12; + row1.w = tmp_13; + } +}; + +template <> +struct Regs_to_rows : public Regs_to_rows_8bit {}; + +template <> +struct Regs_to_rows : public Regs_to_rows_8bit {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Regs_to_rows { + template + static inline __device__ void extract(Acc const& acc, uint2& row0, uint2& row1) { + uint16_t* row0_ptr = reinterpret_cast(&row0); + uint16_t* row1_ptr = reinterpret_cast(&row1); + row0_ptr[0] = acc.u16(0); + row0_ptr[1] = acc.u16(4); + row0_ptr[2] = acc.u16(1); + row0_ptr[3] = acc.u16(5); + + row1_ptr[0] = acc.u16(2); + row1_ptr[1] = acc.u16(6); + row1_ptr[2] = acc.u16(3); + row1_ptr[3] = acc.u16(7); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void add4(uint4& dst, uint4 const& src) { + reinterpret_cast(dst.x) += reinterpret_cast(src.x); + reinterpret_cast(dst.y) += reinterpret_cast(src.y); + reinterpret_cast(dst.z) += reinterpret_cast(src.z); + reinterpret_cast(dst.w) += reinterpret_cast(src.w); +} + +template +inline __device__ void add_vec(uint4& dst, uint4 const& src) { + add4(dst, src); +} + +template <> +inline __device__ void add_vec(uint4& dst, uint4 const& src) { + dst = fmha::hadd8(dst, src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// The base 
class for 32-bit/16-bit accumulator types of imma/qmma. +// TODO Can we port Ampere hmma fp32 to this? +template +struct Smem_tile_o_base_8bit_mma { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // The size of each element. + enum { BYTES_PER_ELEMENT = sizeof(typename Traits::Accumulator_type) }; + + // The amount of bytes per row (without packing or split-k). + enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; + + // The size of each STS. + enum { BYTES_PER_STS = BYTES_PER_ELEMENT * 4 }; + + // The STS Packed Data Type + using Sts_packed_type = typename Uint_from_size_in_bytes::Type; + + // The size of each LDS. + enum { BYTES_PER_LDS = 16 }; + + // The number of threads to store a "row" of the matrix. We force it to 16 for SEQLEN=384. + enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; + + // The number of rows loaded per LDS. + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + // The STS bytes for one quad of threads + enum { BYTES_PER_STS_PER_QUAD = BYTES_PER_STS * 4 }; + + // The xor factor per LDS + // (4 consecutive threads do 64B swizzle for 16B per sts, 32B swizzle for 8B per sts) + enum { XOR_FACTOR = fmha::Div_up::VALUE }; + + // The smem offset in bytes per MMA_N (2 squad threads) + enum { BYTES_OFFSET_PER_MMA_N = BYTES_PER_STS * 8 }; + + // The number of "rows" to process in total. + enum { ROWS = Cta_tile::M }; + + // We want at least one output per thread (if possible). + enum { ROWS_PER_LOOP_ = ROWS <= 64 ? ROWS : (int)Min::VALUE }; + + // We also want to have "complete" MMAs. + enum { ROWS_PER_LOOP = Max::VALUE }; + + // The number of outer loops. + enum { LOOPS = fmha::Div_up::VALUE }; + + // Make sure it matches our expectations. + static_assert(ROWS_PER_LOOP >= (int)Mma_tile::M_PER_MMA_PER_CTA, ""); + + // Do we have to guard against partial writes/reads. + enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; + + // The total number of LDS per loop. + enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; + + // The amount of shared memory. + enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW * Cta_tile::WARPS_K }; + + // The amount of row packing to make sure we have at least 128B per smem row (without split-k). + enum { ROW_PACKING = Max<1, 128 / BYTES_PER_ROW>::VALUE }; + + // Make sure our row packing is correct + static_assert(ROWS_PER_LOOP % ROW_PACKING == 0, ""); + + // The amount of shared memory per row after packing. + enum { BYTES_PER_ROW_WITH_PACKING = BYTES_PER_ROW * ROW_PACKING }; + + // Make sure we have at least 128B per row after packing. + static_assert(BYTES_PER_ROW_WITH_PACKING >= 128, ""); + + // The number of threads per row after packing. + enum { THREADS_PER_ROW_WITH_PACKING = THREADS_PER_ROW * ROW_PACKING }; + + // Make sure we have at least 8 threads per row after packing. + static_assert(THREADS_PER_ROW_WITH_PACKING >= 8, ""); + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + static_assert(WARPS_K > 1 || std::is_same::value, + "Kernel misconfigured. No split-k needed."); + + // Determine the config. 
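+  // Note: the static_assert above requires split-K (WARPS_K > 1) for these 8-bit tiles; the
+  // single exempted traits class (presumably the Ada e4m3 path) is handled by the
+  // WARPS_4x1x1 branch in the constructor below.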
+ enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_4x1x2 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // Ctor. + inline __device__ Smem_tile_o_base_8bit_mma(void* smem, int tidx) { + // Get a 32-bit value for the shared memory address. + uint32_t smem_ = __nvvm_get_smem_pointer(smem); + + // The row/col written by the thread. + int write_row, write_col; + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + write_row = (tidx & 0x20) / 4 + (tidx & 0x1e) / 8; + write_col = (tidx & 0x40) / 8 + (tidx & 0x07); + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + write_row = (tidx & 0x20) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x40) / 8 + (tidx & 0x07); + + // SEQLEN == 128 and HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x20) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x40) / 4 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and HIDDEN_SIZE_PER_HEAD == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + write_row = (tidx & 0x18) / 8; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x07); + + // SEQLEN == 256, 384, 512 and HIDDEN_SIZE_PER_HEAD == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 4 + (tidx & 0x07); + + // SEQLEN == 256, 384 and HIDDEN_SIZE_PER_HEAD == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0xe0) / 2 + (tidx & 0x07); + + // GMMA: HIDDEN_SIZE_PER_HEAD == 64. + } else if (WARPS_4x1x2 && Cta_tile::N == 64) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x80) / 8 + (tidx & 0x07); + + // Ada e4m3_fp32 + } else if (WARPS_4x1x1) { + write_row = (tidx & 0x60) / 2 + (tidx & 0x1c) / 4; + write_col = (tidx & 0x80) / 8 + (tidx & 0x07); + + // Not supported. + } else { + assert(false); + } + + // Assemble the write pointer. + smem_write_ = smem_ + write_row * BYTES_PER_ROW_WITH_PACKING * Cta_tile::WARPS_K + + write_col * BYTES_PER_STS; + + // The element read by each thread. + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + + // Is that thread active on the last LDS? + if (HAS_INCOMPLETE_LDS) { + is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < ROWS_PER_LOOP; + } + + // The XOR params. + constexpr int XOR_MOD = 2 / ROW_PACKING; + + // Take the XOR pattern and the packing into account for the column. + read_col += read_row % ROW_PACKING * XOR_FACTOR; + read_row /= ROW_PACKING; + read_col ^= (read_row % XOR_MOD) * XOR_FACTOR; + + // Assemble the read pointer. + smem_read_ = smem_ + read_row * BYTES_PER_ROW_WITH_PACKING * Cta_tile::WARPS_K + + read_col * BYTES_PER_LDS; + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + // Note: ROWS_PER_LDS does not take packing into account - hence BYTES_PER_ROW. 
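+        // The first term steps down ii * ROWS_PER_LDS logical rows (each physical row packs
+        // WARPS_K slices, hence the extra factor), and the second term selects the jj-th
+        // split-K slice inside the packed row.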
+ int imm = + ii * ROWS_PER_LDS * BYTES_PER_ROW * Cta_tile::WARPS_K + jj * BYTES_PER_ROW_WITH_PACKING; + + // Load... + if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || is_active_for_last_lds_)) { + fmha::lds(tmp[jj], smem_read_ + imm); + } + } + +// Perform the reduction. +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + add_vec(tmp[0], tmp[jj]); + } + + // Write to out. + out[ii] = tmp[0]; + } + } + + // Store the accumulators. + template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { + enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; + + // The number of MMAs that are stored per loop iteration. + enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; + +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { +#pragma unroll + for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) { + Sts_packed_type row_0, row_1; + Regs_to_rows::extract(acc[mi * MMAS_M_PER_LOOP + mj][ni], row_0, row_1); + + /* + (32bit acc) Each thread of a quad writes 16B per STS -> 64B per store. + Account for 2 -> 128B. + (16bit acc) Each thread of a quad writes 8B per STS -> 32B per store. + Account for 2 -> 64B. + */ + int imm_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW * Cta_tile::WARPS_K + + (ni / 2) * BYTES_OFFSET_PER_MMA_N; + int imm_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW * Cta_tile::WARPS_K + + (ni / 2) * BYTES_OFFSET_PER_MMA_N; + + // Store the elements. + fmha::sts(this->smem_write_ + imm_0, row_0); + fmha::sts(this->smem_write_ + imm_1, row_1); + } + // (32bit acc) Each thread of a quad writes 16B per STS -> 64B per store. + // (16bit acc) Each thread of a quad writes 8B per STS -> 32B per store. + if (Mma_tile::MMAS_N == 1) { + // Noop. + } else if (Mma_tile::MMAS_N % 2 == 0) { + this->smem_write_ ^= BYTES_PER_STS_PER_QUAD; + } else { + assert(false && "Unsupported"); + } + } + } + + // The write pointer. + uint32_t smem_write_; + // The write pointer. + uint32_t smem_read_; + // Is the thread active for the last LDS of the series? + int is_active_for_last_lds_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Volta_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. 
+ using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o + : public Smem_tile_o_base_8bit_mma { + // The traits class. + using Traits = fmha::Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Smem_tile_o_base_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_o(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_o_interleaved { + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + enum { VEC = 32 }; + + enum { NUM_SLICES = Cta_tile::N / VEC }; + + static_assert(NUM_SLICES == 1 || NUM_SLICES == 2, ""); + + enum { BYTES_PER_ELEMENT = 4 }; + + enum { BYTES_PER_STS = 16 }; + + enum { BYTES_PER_LDS = 16 }; + + enum { ELTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT }; + + static_assert(VEC * BYTES_PER_ELEMENT == 128, ""); + + enum { BYTES_PER_ROW = Cta_tile::WARPS_K * VEC * BYTES_PER_ELEMENT }; + + // Each row only stores one slice. The other slice starts this many rows below + enum { ROWS_PER_SLICE = Cta_tile::WARPS_M * 16 }; + + enum { TOTAL_ROWS = NUM_SLICES * ROWS_PER_SLICE }; + + enum { BYTES_PER_TILE = BYTES_PER_ROW * TOTAL_ROWS }; + + // LDS + enum { THREADS_PER_ROW = 8 }; + + enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; + + enum { LDS_PER_LOOP = TOTAL_ROWS / ROWS_PER_LDS }; + + // Ctor. + inline __device__ Smem_tile_o_interleaved(void* smem, int tidx) { + smem_ = __nvvm_get_smem_pointer(smem); + + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + + // Warp order (fastest to slowest): m => n => k + // 2x2: 2,2,1 then 2,1,2: mask_m = 0x20, mask_k = 0x40, div_m = 32, div_k = 64 + // 1x4: 1,4,1 then 1,1,4: mask_m = 0x00, mask_k = 0x60, div_m = X, div_k = 32 + // 1x8: 1,8,1 then 1,1,8: mask_m = 0x00, mask_k = 0xe0, div_m = X, div_k = 32 + static_assert(WARPS_N == 1, ""); + + // A thread holds 4 elts of 4B. One slice of 32 elts has 128B. + // Two MMAs in N constitute one slice + + // the slice offset that depends on ni and has to be added later + static_assert(VEC / ELTS_PER_STS == 8, ""); // 8 columns of 4 elements + if (WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2) { + write_row = (tidx & 0x1c) / 4 + (tidx & 0x20) / 2; // warp_m * 16 rows + write_col = (tidx & 0x03) + (tidx & 0x40) / 8; // warp_k * VEC / ELTS_PER_STS + } else { + write_row = (tidx & 0x1c) / 4; + write_col = (tidx & 0x03) + (tidx & 0xe0) / 4; // warp_k * VEC / ELTS_PER_STS + } + write_col ^= (write_row & 0x01) * 4; // left or right 64B + + // this->smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; + + int read_row = tidx / THREADS_PER_ROW; + int read_col = tidx % THREADS_PER_ROW; + read_col ^= (read_row & 0x01) * 4; + this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Store the accumulators. 
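+  // Note: in this interleaved layout the head dimension is split into NUM_SLICES slices of
+  // VEC = 32 values; each slice occupies its own block of ROWS_PER_SLICE rows, and store()
+  // below derives the slice index from the MMA column index ni.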
+ template + inline __device__ void store(Accumulator const (&acc)[M][N], int mi) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + int const slice = ni / NUM_SLICES; + int col = write_col ^ ((ni & 1) * 4); + + uint32_t smem_write_ = smem_ + write_row * BYTES_PER_ROW + col * BYTES_PER_STS; + + // Extract the elements. + uint4 row_0, row_1; + + Regs_to_rows::extract(acc[mi][ni], row_0, row_1); + + // Each thread of a quad writes 16B per STS -> 64B per store. Account for + // 2 -> 128B. + int imm_0 = (slice * ROWS_PER_SLICE + 0) * BYTES_PER_ROW; + int imm_1 = (slice * ROWS_PER_SLICE + 8) * BYTES_PER_ROW; + + // Store the elements. + fmha::sts(smem_write_ + imm_0, row_0); + fmha::sts(smem_write_ + imm_1, row_1); + } + } + + // Load the output fragments. + inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { +#pragma unroll + for (int ii = 0; ii < LDS_PER_LOOP; ++ii) { + // Load the elements before the reduction (split-K). + uint4 tmp[Cta_tile::WARPS_K]; +#pragma unroll + for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) { + int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * VEC * BYTES_PER_ELEMENT; + fmha::lds(tmp[jj], smem_read_ + imm); + } + +// Perform the reduction. +#pragma unroll + for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) { + add4(tmp[0], tmp[jj]); + } + + // Write to out. + out[ii] = tmp[0]; + } + } + + int write_row; + int write_col; + uint32_t smem_write_; + uint32_t smem_read_; + uint32_t smem_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/smem_tile_qkv.h b/csrc/fmha_v2/fmha/smem_tile_qkv.h new file mode 100644 index 0000000000..32caaadb3a --- /dev/null +++ b/csrc/fmha_v2/fmha/smem_tile_qkv.h @@ -0,0 +1,592 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qkv_interleaved + : public fmha::Smem_tile_without_skews { + // The traits class. + using Traits = Traits_; + // The base class. + using Base = fmha::Smem_tile_without_skews; + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The fragment. + + // The size of a single LDS in bytes. 
+ enum { BYTES_PER_LDS = 16 }; + + enum { ROWS_PER_WARP = Cta_tile::THREADS_PER_WARP / Base::THREADS_PER_ROW }; + + using Fragment_a = fmha::Fragment_a; + using Fragment_b = fmha::Fragment_b; + + inline __device__ Smem_tile_qkv_interleaved(char* smem, int tidx) : Base(smem, tidx) {} + + uint32_t offset; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a_base : public Smem_tile_qkv_interleaved { + using Base = Smem_tile_qkv_interleaved; + + static_assert(Base::THREADS_PER_ROW == 128 / 16, ""); + + enum { SMEM_ROWS_PER_WARP = Base::ROWS_PER_WARP }; + + static_assert(SMEM_ROWS_PER_WARP == 4, ""); + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment_a; + + inline __device__ Smem_tile_qk_interleaved_a_base(char* smem, int tidx) : Base(smem, tidx) { + static_assert(Cta_tile::WARPS_K == 1, ""); + static_assert(Cta_tile::WARPS_M == 1 || Cta_tile::WARPS_M == 2, ""); + static_assert(Cta_tile::WARPS_N == 2 || Cta_tile::WARPS_N == 4, ""); + + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + + constexpr int WARP_MASK_M = fmha::Warp_masks::M; + constexpr int WARP_DIV_M = 1 * 1 * Cta_tile::THREADS_PER_WARP; + + int const warp_m = (tidx & WARP_MASK_M) / WARP_DIV_M; + + /* Read address layout for ldsm: + * [ 0 16 1 17 2 18 3 19] + * [20 4 21 5 22 6 23 7] + * [ 8 24 9 25 10 26 11 27] + * [28 12 29 13 30 14 31 15] + */ + int read_row = (tidx & 0x04) / 4 + (tidx & 0x08) / 4 + warp_m * SMEM_ROWS_PER_WARP; + int read_col = (tidx & 0x03) * 2 + (tidx & 0x10) / 16; + read_col ^= (read_row & 0x01); + + this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a + : public Smem_tile_qk_interleaved_a_base { + using Traits = fmha::Volta_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_a_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_a(char* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_M], int ki) { + int slice = ki / 2; + +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; mi++) { + // the data for the second slice sits below the first slice + uint32_t read_ptr = this->smem_ + this->offset + slice * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data; + ldsm_with_lds( + data, read_ptr + mi * Cta_tile::WARPS_M * Base::SMEM_ROWS_PER_WARP * Base::BYTES_PER_ROW); + static_assert(Fragment::NUM_REGS == 2, ""); + frag[mi].reg(0) = data.x; + frag[mi].reg(1) = data.y; + } + + this->offset ^= 16; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a + : public Smem_tile_qk_interleaved_a_base { + using Traits = fmha::Turing_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_a_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_a(char* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void 
load(Fragment (&frag)[Mma_tile::MMAS_M], int ki) { + int slice = ki / 2; + +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; mi++) { + // the data for the second slice sits below the first slice + uint32_t read_ptr = this->smem_ + this->offset + slice * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data; + fmha::ldsm( + data, read_ptr + mi * Cta_tile::WARPS_M * Base::SMEM_ROWS_PER_WARP * Base::BYTES_PER_ROW); + static_assert(Fragment::NUM_REGS == 2, ""); + frag[mi].reg(0) = data.x; + frag[mi].reg(1) = data.y; + } + + this->offset ^= 16; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_a + : public Smem_tile_qk_interleaved_a_base { + using Traits = fmha::Ampere_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_a_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_a(char* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_M], int ki) { +#pragma unroll + for (int mi = 0; mi < Mma_tile::MMAS_M; mi++) { + // the data for the second slice sits below the first slice + uint32_t read_ptr = this->smem_ + this->offset + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint4 data; + fmha::ldsm( + data, read_ptr + mi * Cta_tile::WARPS_M * Base::SMEM_ROWS_PER_WARP * Base::BYTES_PER_ROW); + static_assert(Fragment ::NUM_REGS == 4, ""); + frag[mi].reg(0) = data.x; + frag[mi].reg(1) = data.y; + frag[mi].reg(2) = data.z; + frag[mi].reg(3) = data.w; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b_base : public Smem_tile_qkv_interleaved { + using Base = Smem_tile_qkv_interleaved; + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment_b; + + inline __device__ Smem_tile_qk_interleaved_b_base(char* smem, int tidx) : Base(smem, tidx) { + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + + // 2x2: 2,2,1 then 2,1,2 + // 1x4: 1,4,1 then 1,1,4 + static_assert(WARPS_K == 1, ""); + + constexpr int WARP_MASK_N = fmha::Warp_masks::N; + constexpr int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // Only need to care about warp_n, because if warps_m > 1, both of them should load + // the same data + int const warp = (tidx & WARP_MASK_N) / WARP_DIV_N; + + /* transpose the order of the LDSMs: first along K, then along N + * [ 0 8 1 9 2 10 3 11] + * [12 4 13 5 14 6 15 7] + * [16 24 17 25 18 26 19 27] + * [28 20 29 21 30 22 31 23] + */ + int read_row = (tidx & 0x04) / 4 + (tidx & 0x10) / 8 + warp * Base::ROWS_PER_WARP; + int read_col = (tidx & 0x03) * 2 + (tidx & 0x08) / 8; + read_col ^= (read_row & 0x01); + + this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b + : public Smem_tile_qk_interleaved_b_base { + using Traits = fmha::Volta_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_b_base; + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename 
Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_b(char* smem, int tidx) : Base(smem, tidx) { + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + constexpr int WARP_MASK_N = fmha::Warp_masks::N; + constexpr int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // Only need to care about warp_n, because if warps_m > 1, both of them should load + // the same data + int const warp = (tidx & WARP_MASK_N) / WARP_DIV_N; + + int read_row = (tidx & 0x04) / 4 + (tidx & 0x08) / 4 + warp * Base::ROWS_PER_WARP; + int read_col = (tidx & 0x03) * 2 + (tidx & 0x10) / 16; + read_col ^= (read_row & 0x01); + + this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + int slice = ki / 2; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t read_ptr = this->smem_ + this->offset + slice * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data; + ldsm_with_lds(data, + read_ptr + ni * Base::ROWS_PER_WARP * Cta_tile::WARPS_N * Base::BYTES_PER_ROW); + static_assert(Fragment ::NUM_REGS == 2, ""); + frag[ni].reg(0) = data.x; + frag[ni].reg(1) = data.y; + } + this->offset ^= 16; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b + : public Smem_tile_qk_interleaved_b_base { + using Traits = fmha::Turing_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_b_base; + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_b(char* smem, int tidx) : Base(smem, tidx) { + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + constexpr int WARP_MASK_N = fmha::Warp_masks::N; + constexpr int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; + + // Only need to care about warp_n, because if warps_m > 1, both of them should load + // the same data + int const warp = (tidx & WARP_MASK_N) / WARP_DIV_N; + + int read_row = (tidx & 0x04) / 4 + (tidx & 0x08) / 4 + warp * Base::ROWS_PER_WARP; + int read_col = (tidx & 0x03) * 2 + (tidx & 0x10) / 16; + read_col ^= (read_row & 0x01); + + this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + int slice = ki / 2; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t read_ptr = this->smem_ + this->offset + slice * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data; + fmha::ldsm(data, + read_ptr + ni * Base::ROWS_PER_WARP * Cta_tile::WARPS_N * Base::BYTES_PER_ROW); + static_assert(Fragment ::NUM_REGS == 2, ""); + frag[ni].reg(0) = data.x; + frag[ni].reg(1) = data.y; + } + this->offset ^= 16; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_qk_interleaved_b + : public Smem_tile_qk_interleaved_b_base { + using Traits = fmha::Ampere_imma_int8_int32_traits; + using Base = Smem_tile_qk_interleaved_b_base; + + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment; + + inline __device__ Smem_tile_qk_interleaved_b(char* smem, int tidx) : Base(smem, tidx) {} + + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { +#pragma 
unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t read_ptr = this->smem_ + this->offset + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint4 data; + fmha::ldsm(data, + read_ptr + ni * Base::ROWS_PER_WARP * Cta_tile::WARPS_N * Base::BYTES_PER_ROW); + static_assert(Fragment ::NUM_REGS == 4, ""); + frag[ni].reg(0) = data.x; + frag[ni].reg(1) = data.y; + frag[ni].reg(2) = data.z; + frag[ni].reg(3) = data.w; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b_base + : public Smem_tile_qkv_interleaved { + using Base = Smem_tile_qkv_interleaved; + + using Mma_tile = typename Base::Mma_tile; + // TODO Row or col? + using Fragment = typename Base::Fragment_b; + + inline __device__ Smem_tile_v_interleaved_b_base(char* smem, int tidx) : Base(smem, tidx) { + // // DEBUG. + // static_assert( Cta_tile::N == 64, "" ); + // // END OF DEBUG. + + constexpr int WARPS_M = Cta_tile::WARPS_M; + constexpr int WARPS_N = Cta_tile::WARPS_N; + constexpr int WARPS_K = Cta_tile::WARPS_K; + + // 2x2: 2,2,1 then 2,1,2 + // 1x4: 1,4,1 then 1,1,4 + static_assert(WARPS_N == 1, ""); + + // Don't need to consider WARP M. For two warps in M, both would read the same tile + constexpr int WARP_MASK_K = fmha::Warp_masks::K; + constexpr int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; + + // the static assert above ensures, that only warp_m or warp_k is non-zero + int const warp = (tidx & WARP_MASK_K) / WARP_DIV_K; + + /* LDSM.T addresses: warps are split in two to match BMM1-GEMM-N (= BMM2-GEMM-K) register + * layout + * <== GEMM-N = D = 64 ==> + * [ 0: 0 0 1 0 2 0 3 0] WARP 0 + * [ 1: 0 4 0 5 0 6 0 7] + * [ 2: 8 0 9 0 10 0 11 0] + * [ 3: 0 12 0 13 0 14 0 15] + * [ 4: 0 0 0 0 0 0 0 0] WARP 1 + * [ 5: 0 0 0 0 0 0 0 0] + * [ 6: 0 0 0 0 0 0 0 0] + * [ 7: 0 0 0 0 0 0 0 0] + * [ 8: 0 0 0 0 0 0 0 0] WARP 2 + * [ 9: 0 0 0 0 0 0 0 0] + * [10: 0 0 0 0 0 0 0 0] + * [11: 0 0 0 0 0 0 0 0] + * [12: 0 0 0 0 0 0 0 0] WARP 3 + * [13: 0 0 0 0 0 0 0 0] + * [14: 0 0 0 0 0 0 0 0] + * [15: 0 0 0 0 0 0 0 0] + * [16: 16 0 17 0 18 0 19 0] WARP 0 + * [17: 0 20 0 21 0 22 0 23] + * [18: 24 0 25 0 26 0 27 0] + * [19: 0 28 0 29 0 30 0 31] + * etc ... + */ + + // TODO this is a bit misleading, as 4 rows per warp applies to the + // row-major tiles above. In this smem tile, a warp actually owns 8 rows in + // SMEM, but we have 4 rows per slice + + // TODO would be good to rename to SMEM_ROWS_PER_WARP to make this clearer + static_assert(Base::ROWS_PER_WARP == 4, ""); + + read_row = ((tidx & 0x0f) / 4) + warp * Base::ROWS_PER_WARP; + read_col = (tidx & 0x03) * 2; + read_col ^= (read_row & 0x01); + + // this->offset = read_row * Base::BYTES_PER_ROW + read_col * Base::BYTES_PER_LDS; + } + + int read_row; + int read_col; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b + : public Smem_tile_v_interleaved_b_base { + using Traits = fmha::Volta_imma_int8_int32_traits; + using Base = Smem_tile_v_interleaved_b_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment_b; + + // Ctor. + inline __device__ Smem_tile_v_interleaved_b(char* smem, int tidx) : Base(smem, tidx) {} + + // Load fragments from shared memory. 
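+  // Each ni selects a swizzled 16B column; the tile is fetched with ldsmt_with_lds
+  // (Volta lacks LDSM.T, so this presumably emulates the transposed load with plain
+  // LDS) and swizzle_rows repacks the two registers into the (n, k) order expected by
+  // the BMM2 B fragment.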
+ inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + // static_assert(Mma_tile::MMAS_K == 4, ""); + static_assert(Mma_tile::MMAS_N == 4, ""); + static_assert(Base::ROWS_PER_WARP == 4, ""); + // static_assert(Cta_tile::WARPS_K == 2, ""); + + int offset_k = ki * Cta_tile::WARPS_K * Base::ROWS_PER_WARP; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t offset = (this->read_row + offset_k) * Base::BYTES_PER_ROW + + (this->read_col ^ (ni & 1)) * Base::BYTES_PER_LDS; + + // for the next 32B in N, we have to jump down K rows, so K / 4 rows in + // smem, which stores 4 canonical 32B rows per 128B + offset += (ni / 2) * Cta_tile::K / 4 * Base::BYTES_PER_ROW; + uint32_t read_ptr = this->smem_ + offset; // + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data = {0, 0}; + ldsmt_with_lds(data, read_ptr); + static_assert(Fragment ::NUM_REGS == 2, ""); + swizzle_rows(frag[ni].reg(0), frag[ni].reg(1), data.x, data.y); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b + : public Smem_tile_v_interleaved_b_base { + using Traits = fmha::Turing_imma_int8_int32_traits; + using Base = Smem_tile_v_interleaved_b_base; + using Mma_tile = typename Base::Mma_tile; + using Fragment = typename Base::Fragment_b; + + // Ctor. + inline __device__ Smem_tile_v_interleaved_b(char* smem, int tidx) : Base(smem, tidx) {} + + // Load fragments from shared memory. + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + static_assert(Mma_tile::MMAS_N == 4, ""); + static_assert(Base::ROWS_PER_WARP == 4, ""); + + int offset_k = ki * Cta_tile::WARPS_K * Base::ROWS_PER_WARP; +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t offset = (this->read_row + offset_k) * Base::BYTES_PER_ROW + + (this->read_col ^ (ni & 1)) * Base::BYTES_PER_LDS; + // for the next 32B in N, we have to jump down K rows, so K / 4 rows in + // smem, which stores 4 canonical 32B rows per 128B + offset += (ni / 2) * Cta_tile::K / 4 * Base::BYTES_PER_ROW; + uint32_t read_ptr = this->smem_ + offset; // + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data = {0, 0}; + fmha::ldsmt(data, read_ptr); + static_assert(Fragment ::NUM_REGS == 2, ""); + swizzle_rows(frag[ni].reg(0), frag[ni].reg(1), data.x, data.y); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_interleaved_b + : public Smem_tile_v_interleaved_b_base { + // The instruction traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Smem_tile_v_interleaved_b_base; + // The tile of MMAs. + using Mma_tile = typename Base::Mma_tile; + // The fragment loaded. + using Fragment = typename Base::Fragment_b; + + // Ctor. + inline __device__ Smem_tile_v_interleaved_b(char* smem, int tidx) : Base(smem, tidx) {} + + // Load from shared memory. + inline __device__ void load(Fragment (&frag)[Mma_tile::MMAS_N], int ki) { + int offset_k = ki * Cta_tile::WARPS_K * Base::ROWS_PER_WARP * 2; + static_assert(Cta_tile::K != 192 || Mma_tile::MMAS_K == 2, ""); +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ni++) { + uint32_t offset = (this->read_row + offset_k) * Base::BYTES_PER_ROW + + (this->read_col ^ (ni & 1)) * Base::BYTES_PER_LDS; + + // For the next 32B in N, we have to jump down K rows, so K / 4 rows in smem, which + // stores 4 canonical 32B rows per 128B. 
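+  // The Ampere fragment carries four registers, so each ki issues two LDSM.T loads and
+  // offset_k advances by twice the per-warp row count; for Cta_tile::K == 192 the
+  // second LDSM.T is only issued when ki == 0 (see the comment inside the loop).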
+ offset += (ni / 2) * Cta_tile::K / 4 * Base::BYTES_PER_ROW; + uint32_t read_ptr = this->smem_ + offset; // + ki * Base::ROWS * Base::BYTES_PER_ROW / 2; + uint2 data0 = {0, 0}; + uint2 data1 = {0, 0}; + fmha::ldsmt(data0, read_ptr); + + if (Cta_tile::K != 192 || ki == 0) { + static_assert(Cta_tile::K != 192 || Mma_tile::MMAS_K == 2); + // For 192, with 4 warps, we need 128 rows of K, so for the second ldsm, we need + // only 2x instead of 4x. + int imm = Cta_tile::WARPS_K * Base::ROWS_PER_WARP * Base::BYTES_PER_ROW; + fmha::ldsmt(data1, read_ptr + imm); + } + + static_assert(Fragment ::NUM_REGS == 4, ""); + swizzle_rows(frag[ni].reg(0), frag[ni].reg(2), data0.x, data0.y); + swizzle_rows(frag[ni].reg(1), frag[ni].reg(3), data1.x, data1.y); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/smem_tile_v.h b/csrc/fmha_v2/fmha/smem_tile_v.h new file mode 100644 index 0000000000..67a02f37ca --- /dev/null +++ b/csrc/fmha_v2/fmha/smem_tile_v.h @@ -0,0 +1,1008 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template class Rows_per_xor_pattern, + int BUFFERS_PER_TILE = 1> +struct Smem_tile_v_hmma { + using Base = Smem_tile_without_skews::VALUE, 1>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_hmma::Base { + // The traits class. + using Traits = fmha::Volta_hmma_fp16_16x16x16_traits; + // The base class. + using Base = typename Smem_tile_v_hmma::Base; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x08) / 8; + read_col = (tidx & 0x10) / 16 + (tidx & 0x03) * 2; + + // SEQLEN == 128 and N == 32. 
+ } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x08) / 4 + (tidx & 0x02) / 2; + read_col = (tidx & 0x10) / 16 + (tidx & 0x01) * 4; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 4 + (tidx & 0x08) / 2 + (tidx & 0x03); + read_col = (tidx & 0x10) / 16; + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x08) / 8; + read_col = (tidx & 0x10) / 16 + (tidx & 0x03) * 2; + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x08) / 4 + (tidx & 0x02) / 2; + read_col = (tidx & 0x10) / 16 + (tidx & 0x01) * 4; + + // SEQLEN == 256, 384 and 512 and N == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && + (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0xe0) / 2 + (tidx & 0x08) / 2 + (tidx & 0x03); + read_col = (tidx & 0x10) / 16; + + // ANY SEQLEN and N == 16. + } else if (WARPS_4x1x1 && Cta_tile::N == 16) { + read_row = (tidx & 0x08) / 8; + read_col = (tidx & 0x10) / 16 + (tidx & 0x03) * 2; + + // ANY SEQLEN and N == 32. + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + read_row = (tidx & 0x08) / 4 + (tidx & 0x02) / 2; + read_col = (tidx & 0x10) / 16 + (tidx & 0x01) * 4; + + // ANY SEQLEN and N == 64/128/256. + } else if (WARPS_4x1x1 && (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0x08) / 2 + (tidx & 0x03); + read_col = (tidx & 0x10) / 16; + + // Not supported! + } else { + assert(false); + } + + // Apply the XOR for the column. + read_col ^= read_row % Base::ROWS_PER_XOR_PATTERN; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The column offset. + int offset = this->smem_read_offset_ ^ (ni * 2 * BYTES_PER_LDS); + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // The rows. + int row_0 = ki * 16 * Cta_tile::WARPS_K + 0; + int row_1 = ki * 16 * Cta_tile::WARPS_K + 8; + + // Load the data using 2x LDS.128. + uint4 tmp; + fmha::lds(tmp, this->smem_ + offset + row_0 * Base::BYTES_PER_ROW_BEFORE_PACKING); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + + fmha::lds(tmp, this->smem_ + offset + row_1 * Base::BYTES_PER_ROW_BEFORE_PACKING); + b[ni].reg(4) = tmp.x; + b[ni].reg(5) = tmp.y; + b[ni].reg(6) = tmp.z; + b[ni].reg(7) = tmp.w; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_turing_hmma + : public Smem_tile_v_hmma::Base { + // The base class. + using Base = typename Smem_tile_v_hmma::Base; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v_turing_hmma(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. 
+ enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x04) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x06) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 4 + (tidx & 0x07); + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x04) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x06) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 384, 512 and N == 64, 128, 256. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && + (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0xe0) / 2 + (tidx & 0x07); + read_col = (tidx & 0x07); + + // ANY SEQLEN and N == 16. + } else if (WARPS_4x1x1 && Cta_tile::N == 16) { + read_row = (tidx & 0x04) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // ANY SEQLEN and N == 32. + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + read_row = (tidx & 0x06) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // ANY SEQLEN and N == 64/128/256. + } else if ((WARPS_4x1x1) && (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0x07); + read_col = (tidx & 0x07); + + // Not supported! + } else { + assert(false); + } + + // The 2nd HMMA. + read_col ^= (tidx & 0x08) / 8; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // For even values of k value we jump by 16*WARPS_K rows and for odd, we jump by 8 rows. + int row = (ki / 2) * 16 * Cta_tile::WARPS_K / ROW_PACKING + (ki % 2) * 8 / ROW_PACKING; + + // Load the data using LDSM.MT88.2. + uint2 tmp; + fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + } + + // Move to the next N position. + if (Mma_tile::MMAS_N == 1) { + ; + } else if (Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); + } else if (Mma_tile::MMAS_N == 8) { + this->smem_read_offset_ ^= BYTES_PER_LDS * ((ni & 1) == 0 ? 2 : ((ni & 3) == 3 ? 14 : 6)); + } else if (Mma_tile::MMAS_N == 16) { + this->smem_read_offset_ ^= BYTES_PER_LDS * ((ni & 1) == 0 ? 
2 + : ((ni & 7) == 7) ? 30 + : (((ni & 3) == 3) ? 14 : 6)); + } else { + assert(false); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_turing_hmma { + // The base class. + using Base = Smem_tile_v_turing_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_turing_hmma { + // The base class. + using Base = Smem_tile_v_turing_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template class Rows_per_xor_pattern, + int BUFFERS_PER_TILE = 1> +struct Smem_tile_v_imma { + using Base = Smem_tile_without_skews::VALUE, 1>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_imma::Base { + // The traits class. + using Traits = Volta_imma_int8_int32_traits; + // The base class. + using Base = typename Smem_tile_v_imma::Base; + + // DEBUG. + static_assert(Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING == 2, ""); + // END OF DEBUG. + + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) { + // The row/col read by the thread. + int read_row, read_col; + + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 32 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 16 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 384, 512 and N == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // Not supported. + } else { + assert(false); + } + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. 
+ inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { + static_assert(Mma_tile::MMAS_K == 2 || Mma_tile::MMAS_K == 3 || Mma_tile::MMAS_K == 4 || + Mma_tile::MMAS_K == 6, + ""); +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // Jump by 8*16 rows per K but account for packing. + int row = ki * 16 * Cta_tile::WARPS_K / ROW_PACKING; + + // We emulate the Turing logic, which loads the data using LDSM.MT88.2: + // uint2 tmp; + // fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + // this call fetches two 8x16 matrices, stacked on top of each other + + // we fake LDSM.MT88.2, with 2 LDS.128 and a shuffle: + // - T 0 - T 7 have the smem addresses of LDSM 0, each should do 16B loads + // - T 8 - T15 have the smem addresses of LSDM 1, each should do 16B loads + int const lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + + uint4 tmp16{0, 0, 0, 0}; // 16B + + if (lane < 16) { + fmha::lds(tmp16, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + } + + uint16_t* tmp16c = reinterpret_cast(&tmp16); // 8x2B: we move pairs + + uint2 tmp; // 2*4B + uint16_t* t = reinterpret_cast(&tmp); // 4x2B + + int const src_col = lane / 4; // 0 - 7 + int const src_row = lane % 4 * 2; + +// We have to shuffle the values to distribute them in the warp. +#pragma unroll + for (int it = 0; it < 8; it++) { + uint16_t val, x, y; + val = tmp16c[it]; + x = __shfl_sync(uint32_t(-1), val, src_row + 0); + __syncwarp(); + y = __shfl_sync(uint32_t(-1), val, src_row + 1); + __syncwarp(); + + if (src_col == it) { + t[0] = x; + t[1] = y; + } + val = tmp16c[it]; + x = __shfl_sync(uint32_t(-1), val, src_row + 8); + __syncwarp(); + y = __shfl_sync(uint32_t(-1), val, src_row + 9); + __syncwarp(); + + if (src_col == it) { + t[2] = x; + t[3] = y; + } + } + + // Repack the elements. With LDSM.T, thread 0 has the following elements in its two + // regs: + // + // R0 = [(n=0 k=0), (n=1 k=0), (n=0 k=8), (n=1 k=8)] + // R1 = [(n=0 k=1), (n=1 k=1), (n=0 k=9), (n=1 k=9)] + // + // We want to repack the values as: + // + // R0 = [(n=0 k=0), (n=0 k=1), (n=0 k=8), (n=0 k=9)] + // R1 = [(n=1 k=0), (n=1 k=1), (n=1 k=8), (n=1 k=9)] + // + // Since that this layout corresponds to the layout of elements in the Fragment_a from + // P. + + swizzle_rows(b[ni].reg(0), b[ni].reg(1), tmp.x, tmp.y); + } + + // Move to the next N position. + if (Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 1 : 3); + } else { + assert(false); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_imma::Base { + // The traits class. + using Traits = Turing_imma_int8_int32_traits; + // The base class. + using Base = typename Smem_tile_v_imma::Base; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. 
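+    // The WARPS_AxBxC names encode the warp layout as WARPS_M x WARPS_N x WARPS_K,
+    // e.g. WARPS_2x1x2 means 2 warps along M, 1 along N and 2 along K.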
+ enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 32. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 32 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 16 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 384, 512 and N == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // Not supported. + } else { + assert(false); + } + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { + static_assert(Mma_tile::MMAS_K == 2 || Mma_tile::MMAS_K == 3 || Mma_tile::MMAS_K == 4 || + Mma_tile::MMAS_K == 6 || Mma_tile::MMAS_K == 8, + ""); +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // Jump by 8*16 rows per K but account for packing. + int row = ki * 16 * Cta_tile::WARPS_K / ROW_PACKING; + + // Load the data using LDSM.MT88.2. + uint2 tmp; + fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); + + // Repack the elements. With LDSM.T, thread 0 has the following elements in its two + // regs: + // + // R0 = [(n=0 k=0), (n=1 k=0), (n=0 k=8), (n=1 k=8)] + // R1 = [(n=0 k=1), (n=1 k=1), (n=0 k=9), (n=1 k=9)] + // + // We want to repack the values as: + // + // R0 = [(n=0 k=0), (n=0 k=1), (n=0 k=8), (n=0 k=9)] + // R1 = [(n=1 k=0), (n=1 k=1), (n=1 k=8), (n=1 k=9)] + // + // Since that this layout corresponds to the layout of elements in the Fragment_a from + // P. + + swizzle_rows(b[ni].reg(0), b[ni].reg(1), tmp.x, tmp.y); + + // b[ni].reg(0) = tmp.x; + // b[ni].reg(1)= tmp.y; + } + + // Move to the next N position. + if (Mma_tile::MMAS_N == 1) { + // Noop. + } else if (Mma_tile::MMAS_N == 2) { + this->smem_read_offset_ ^= BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N == 4) { + this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 1 : 3); + } else { + assert(false); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v_ampere_hmma + : public Smem_tile_v_hmma::Base { + // The base class. + using Base = typename Smem_tile_v_hmma::Base; + // The MMA tile. 
+ using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v_ampere_hmma(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + // Flash Attention uses WARPS_4x1x1 + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 128 and N == 64/128/256. + } else if (WARPS_2x1x2 && (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0x40) / 4 + (tidx & 0x0f); + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 16. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 384, 512 and N == 64/128/256. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && + (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256)) { + read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); + read_col = (tidx & 0x07); + + // ANY SEQLEN and N == 16. + } else if (WARPS_4x1x1 && Cta_tile::N == 16) { + read_row = (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // ANY SEQLEN and N == 32. + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + read_row = (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // ANY SEQLEN and N == 64/128/256. + } else if (WARPS_4x1x1 && (Cta_tile::N == 64 || Cta_tile::N == 128 || Cta_tile::N == 256 || + Cta_tile::N == 512)) { + read_row = (tidx & 0x0f); + read_col = (tidx & 0x07); + + // Not supported. + } else { + assert(false); + } + + // The 2nd HMMA. + read_col ^= (tidx & 0x10) / 16; + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // Jump by 16 * #warps row. Account for the packing. + int row = ki * 16 * Cta_tile::WARPS_K / ROW_PACKING; + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // Jump by 16 * #warps row. Account for the packing. + int row = ki * 16 * Cta_tile::WARPS_K / ROW_PACKING; + + // Load the data using LDSM.MT88.2. 
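+        // This variant fetches a uint4, i.e. four 8x8 16-bit tiles in one transposed
+        // LDSM issue, filling all four registers of the B fragment for this ni.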
+ uint4 tmp; + fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + + row * Base::BYTES_PER_ROW); + b[ni].reg(0) = tmp.x; + b[ni].reg(1) = tmp.y; + b[ni].reg(2) = tmp.z; + b[ni].reg(3) = tmp.w; + } + + // Move the pointer for the next ni. I expect the compiler to not recompute those. + static_assert(Mma_tile::MMAS_N <= 64, ""); + if (Mma_tile::MMAS_N >= 32 && ni % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 16 && ni % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 8 && ni % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 4 && ni % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; + } else if (Mma_tile::MMAS_N >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_hmma { + // The base class. + using Base = Smem_tile_v_ampere_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_hmma { + // The base class. + using Base = Smem_tile_v_ampere_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_hmma { + // The base class. + using Base = Smem_tile_v_ampere_hmma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +template +struct Smem_tile_v_ampere_8bit_mma + : public Smem_tile_v_imma::Base { + // The base class. + using Base = typename Smem_tile_v_imma::Base; + // The MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + // The fragment. + using Fragment = fmha::Fragment_b; + + // The size of a single LDS in bytes. + enum { BYTES_PER_LDS = 16 }; + + // Ctor. + inline __device__ Smem_tile_v_ampere_8bit_mma(void* smem, int tidx) : Base(smem, tidx) { + // Warps. + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + enum { WARPS_K = Cta_tile::WARPS_K }; + + // Determine the config. + enum { WARPS_2x1x2 = WARPS_M == 2 && WARPS_N == 1 && WARPS_K == 2 }; + + enum { WARPS_1x1x8 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 8 }; + + enum { WARPS_1x1x4 = WARPS_M == 1 && WARPS_N == 1 && WARPS_K == 4 }; + + enum { WARPS_4x1x1 = WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1 }; + + // The row/col read by the thread. + int read_row, read_col; + + // SEQLEN == 128 and N == 16. + if (WARPS_2x1x2 && Cta_tile::N == 16) { + read_row = (tidx & 0x40) / 32 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 128 and N == 32. + } else if (WARPS_2x1x2 && Cta_tile::N == 32) { + read_row = (tidx & 0x40) / 16 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 128 and N == 64. + } else if (WARPS_2x1x2 && Cta_tile::N == 64) { + read_row = (tidx & 0x40) / 8 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + + // SEQLEN == 256, 512 and N == 16. 
+ } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 16) { + read_row = (tidx & 0xe0) / 16 + (tidx & 0x08) / 8; + read_col = (tidx & 0x07); + + // SEQLEN == 256, 512 and N == 32. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 32) { + read_row = (tidx & 0xe0) / 8 + (tidx & 0x0c) / 4; + read_col = (tidx & 0x03) * 2 + (tidx & 0x04) / 4; + + // SEQLEN == 256, 384, 512 and N == 64. + } else if ((WARPS_1x1x8 || WARPS_1x1x4) && Cta_tile::N == 64) { + read_row = (tidx & 0xe0) / 4 + (tidx & 0x0e) / 2; + read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; + } else if (WARPS_4x1x1 && Cta_tile::N == 32) { + read_row = (tidx % 32) / 4; + read_col = read_row % 2 + (tidx % 4) * 2; + } else if (WARPS_4x1x1 && Cta_tile::N == 64) { + read_row = (tidx % 32) / 2; + read_col = read_row % 4 + (tidx & 0x01) * 4; + } else if (WARPS_4x1x1 && (Cta_tile::N >= 128)) { + read_row = tidx % 32; + read_col = tidx % 8; + + // Not supported. + } else { + assert(false); + } + + // The shared memory offset. + this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + // Load from shared memory. + inline __device__ void load(Fragment (&b)[Mma_tile::VALID_MMAS_N], int ki) { +// static_assert(Mma_tile::MMAS_K == 3 || Mma_tile::MMAS_K == 2 || Mma_tile::MMAS_K == 1, ""); +#pragma unroll + for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) { + // The amount of row packing. + enum { ROW_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING }; + + // // Make sure we do not end up with weird values :) + // static_assert(Cta_tile::WARPS_K % ROW_PACKING == 0, ""); + + // Skip N paddings + if (ni < Mma_tile::VALID_MMAS_N) { + // Jump by 8*32 rows per K but account for the fact that we have packing. + int row_0 = (ki * 32 + 0 * 16) * Cta_tile::WARPS_K / ROW_PACKING; + int row_1 = (ki * 32 + 1 * 16) * Cta_tile::WARPS_K / ROW_PACKING; + + // Load the data using LDSM.MT88.2. + uint32_t smem = this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_; + uint2 tmp_0; + fmha::ldsmt(tmp_0, smem + row_0 * Base::BYTES_PER_ROW); + + // Load the next two values. + uint2 tmp_1 = make_uint2(0u, 0u); + if constexpr (Cta_tile::K > 16) { + fmha::ldsmt(tmp_1, smem + row_1 * Base::BYTES_PER_ROW); + } + + // Repack the elements. With LDSM.T, thread 0 has the following elements in its 4 regs: + // + // R0 = [(n=0 k= 0), (n=1 k= 0), (n=0 k= 1), (n=1 k= 1)] + // R1 = [(n=0 k= 8), (n=1 k= 8), (n=0 k= 9), (n=1 k= 9)] + // R2 = [(n=0 k=128), (n=1 k=128), (n=0 k=129), (n=1 k=129)] + // R3 = [(n=0 k=136), (n=1 k=136), (n=0 k=137), (n=1 k=137)] + // + // We want to repack the values as: + // + // R0 = [(n=0 k= 0), (n=0 k= 1), (n=0 k= 8), (n=0 k= 9)] + // R1 = [(n=0 k=128), (n=0 k=129), (n=0 k=136), (n=0 k=137)] + // R2 = [(n=1 k= 0), (n=1 k= 1), (n=1 k= 8), (n=1 k= 9)] + // R3 = [(n=1 k=128), (n=1 k=129), (n=1 k=136), (n=1 k=137)] + // + // Since this layout corresponds to the layout of elements in the Fragment_a from P. + + swizzle_rows(b[ni].reg(0), b[ni].reg(2), tmp_0.x, tmp_0.y); + swizzle_rows(b[ni].reg(1), b[ni].reg(3), tmp_1.x, tmp_1.y); + } + + // Move to the next N position. 
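+      // The XOR ladder below advances smem_read_offset_ through the swizzled column
+      // order: toggling 1, 3, 7, 15 or 31 LDS strides (depending on where ni sits in
+      // the pattern) visits every 16B column exactly once over the MMAS_N iterations.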
+ if (Mma_tile::MMAS_N >= 32 && ni % 16 == 15) { + this->smem_read_offset_ ^= 31 * BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N >= 16 && ni % 8 == 7) { + this->smem_read_offset_ ^= 15 * BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N >= 8 && ni % 4 == 3) { + this->smem_read_offset_ ^= 7 * BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N >= 4 && ni % 2 == 1) { + this->smem_read_offset_ ^= 3 * BYTES_PER_LDS; + } else if (Mma_tile::MMAS_N >= 2) { + this->smem_read_offset_ ^= 1 * BYTES_PER_LDS; + } else { + assert(false); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_8bit_mma { + // The base class. + using Base = + Smem_tile_v_ampere_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_8bit_mma { + // The base class. + using Base = Smem_tile_v_ampere_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_v + : public Smem_tile_v_ampere_8bit_mma { + // The base class. + using Base = Smem_tile_v_ampere_8bit_mma; + + // Ctor. + inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/softmax.h b/csrc/fmha_v2/fmha/softmax.h new file mode 100644 index 0000000000..68ecea49b9 --- /dev/null +++ b/csrc/fmha_v2/fmha/softmax.h @@ -0,0 +1,3964 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include + +#include "fmha/fragment.h" +#include "fmha/utils.h" + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Sum_ { + enum { IS_SUM = 1 }; + + static inline __device__ float apply(float x, float y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Max_ { + enum { IS_SUM = 0 }; + + static inline __device__ float apply(float x, float y) { return fmaxf(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float apply_exp_(float x, float max) { + return isinf(x) ? 
0.f : __expf(x - max); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ float apply_exp_<2>(float x, float max) { + return __expf(x - max); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float get_alibi_head_scaling_factor(int const in_head_id, + AlibiParams const& params) { + int const head_id = params.head_idx_offset + in_head_id; + if (head_id < params.h_pow_2) { + // 2^(head_id * -8 / h) + return exp2f((head_id + 1) * 2 * params.alibi_neg4_div_h) * params.scale_after_alibi; + } else { + // 1,3,5... etc + float const adjusted_head_id = 2 * (head_id - params.h_pow_2) + 1; + // 2^(adjusted_head_id * -4 / h) + return exp2f(adjusted_head_id * params.alibi_neg4_div_h) * params.scale_after_alibi; + ; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ReadType { + using T = float; +}; + +template <> +struct ReadType<4> { + using T = float; +}; + +template <> +struct ReadType<8> { + using T = float2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_reduce { + // Helper class to distribute MMA tiles reduced over rows per warp over quads. + + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + static constexpr int ROWS = WARPS_M * MMAS_M * 16; + static constexpr int COLS = WARPS_N; + static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; + static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); + static constexpr int ELTS_PER_TILE = ROWS * COLS; + + static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; + static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; + static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; + + using read_t = typename ReadType::T; + + __device__ inline Smem_tile_reduce(float* smem_, int const tidx) { + int lane = tidx % 32; + int warp = tidx / 32; + + int warp_m = warp % WARPS_M; + int warp_n = warp / WARPS_M; + + qid_ = lane % 4; + int qp = lane / 4; + + // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. + // This won't affect reading as we assume commutative reduction ops. 
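+    // Writes go out as plain floats, one column per warp_n; reads reinterpret the same
+    // buffer as read_t (float or float2 per the ReadType specializations), so the four
+    // quad lanes of a row pull back all WARPS_N partial values in one access each.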
+ int const col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); + smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; + smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; + } + + __device__ inline void store(float (&frag)[2 * MMAS_M]) { + if (qid_ == 0) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * WARPS_N; + smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; + smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; + } + } + } + + __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; mi++) { + int offset = mi * 16 * 4; + frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; + frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; + } + } + + int qid_; + float* smem_write_; + read_t* smem_read_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_base { + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // The number of groups of warp such that we have at most 4 warps writing consecutive elements. + enum { GROUPS = fmha::Div_up::VALUE }; + + // The number of elements that we are going to store per row. + enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS }; + + // The number of rows. + enum { ROWS = Cta_tile::M * GROUPS }; + + // The total number of elements. + enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW }; + + // If shared memory is used + enum { USE_SHARED_MEMORY = Cta_tile::WARPS_N > 1 }; + + // DEBUG. + static_assert(ELEMENTS == Cta_tile::M * Cta_tile::WARPS_N, ""); + + // END OF DEBUG. + + // The number of rows per thread. + enum { ROWS_PER_THREAD = MMAS_M * 2 }; + + // Ctor. + template + inline __device__ Softmax_base(Params const& params, void* smem, int bidb, int tidx) + : smem_(reinterpret_cast(smem)), tidx_(tidx) { + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Decompose the warp index into M and N. + int warp_m = warp % Cta_tile::WARPS_M; + int warp_n = warp / Cta_tile::WARPS_M; + + // Decompose the warp-n index into group/position-inside-the-group. + int warp_g = warp_n / ELEMENTS_PER_ROW; + int warp_i = warp_n % ELEMENTS_PER_ROW; + + // The location written by the threads. + int write_row = warp_g * Cta_tile::M + warp_m * Mma_tile::M_PER_MMA + lane / 4; + int write_col = warp_i; + + // Assemble the write pointer. + smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; + + // Assemble the read pointer. + smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; + } + + // Apply mask before softmax. Use 1 byte per MMA distributed as 2x4. 
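+  // Each thread owns a 2x4 patch of every MMA tile (rows 2*mi+ii, columns 4*ni+jj);
+  // positions the mask rejects are set to -FLT_MAX so they underflow to zero after
+  // the exp.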
+ template + inline __device__ void apply_mask(Mask const& mask) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + if (!mask.is_valid(mi, ni, ii, jj)) { + elt_[2 * mi + ii][4 * ni + jj] = -FLT_MAX; + } + } + } + } + } + } + + template + inline __device__ void apply_mask_alibi(Mask const& mask, int head_id, + AlibiParams const& alibi_params) { + // 'if constexpr' because ALiBi is only defined for causal masks + if constexpr (Kernel_traits::CAUSAL_MASK) { + float m = get_alibi_head_scaling_factor(head_id, alibi_params); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + int row, col; + mask.get_row_col(row, col, mi, ni, ii, jj); + if (mask.is_valid(row, col)) { + // Since softmax is shift invariant, + // it is sufficient just to use the column as the multiplier + elt_[2 * mi + ii][4 * ni + jj] = + elt_[2 * mi + ii][4 * ni + jj] * alibi_params.scale_after_alibi + + m * (col + alibi_params.sequence_pos_offset); + } else { + elt_[2 * mi + ii][4 * ni + jj] = -FLT_MAX; + } + } + } + } + } + } else { + __builtin_unreachable(); + } + } + + // Apply the mask to unpacked data. + inline __device__ void apply_mask(uint32_t const (&packed_mask)[MMAS_M]) { + // This code works only if we have MMAS_N <= 4. + static_assert(MMAS_N <= 4, ""); + + // Expand the mask. + int mask[MMAS_M * 2][MMAS_N * 4]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + mask[2 * mi + 0][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 0)); + mask[2 * mi + 0][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 1)); + mask[2 * mi + 1][4 * ni + 0] = packed_mask[mi] & (1u << (8 * ni + 2)); + mask[2 * mi + 1][4 * ni + 1] = packed_mask[mi] & (1u << (8 * ni + 3)); + mask[2 * mi + 0][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 4)); + mask[2 * mi + 0][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 5)); + mask[2 * mi + 1][4 * ni + 2] = packed_mask[mi] & (1u << (8 * ni + 6)); + mask[2 * mi + 1][4 * ni + 3] = packed_mask[mi] & (1u << (8 * ni + 7)); + } + } + +// Apply the mask. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + if (!mask[mi][ni]) { + elt_[mi][ni] = -FLT_MAX; + } + } + } + } + + // Mask the elements that are outside the the sequence length. + inline __device__ void apply_mask(int const actual_seqlen) { + // The warp/lane decomposition. + int const warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int const lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + + // The warp in the n dimension. + int const warp_n = warp / Cta_tile::WARPS_M; + // The position within a quad. + int const quad_lane = lane % 4; + +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Determine the position in the sequence. 
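+        // Column index: ni * N_PER_MMA_PER_CTA plus 16 per warp in N; within a quad a
+        // thread's four elements sit at +0, +1, +8 and +9 relative to 2 * quad_lane,
+        // which is exactly what the four bounds checks below test.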
+ int const offset = ni * Mma_tile::N_PER_MMA_PER_CTA + warp_n * 16; + if (offset + 0 + 2 * quad_lane >= actual_seqlen) { + elt_[mi][4 * ni + 0] = -FLT_MAX; // 0 + } + if (offset + 1 + 2 * quad_lane >= actual_seqlen) { + elt_[mi][4 * ni + 1] = -FLT_MAX; // 1 + } + if (offset + 8 + 2 * quad_lane >= actual_seqlen) { + elt_[mi][4 * ni + 2] = -FLT_MAX; // 8 + } + if (offset + 9 + 2 * quad_lane >= actual_seqlen) { + elt_[mi][4 * ni + 3] = -FLT_MAX; // 9 + } + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const max) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max); + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_scale_exp(float const (&max)[MMAS_M * 2], float scale_bmm1) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] = apply_exp_(scale_bmm1 * elt_[mi][ni], max[mi]); + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); + } + } + } + + // Do a warp-wide reduction. + template + inline __device__ void reduce_Nx1(float (&dst)[MMAS_M * 2]) { +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float tmp[2] = {0.f, 0.f}; +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + tmp[0] += elt_[mi][4 * ni + 0] + elt_[mi][4 * ni + 1]; + tmp[1] += elt_[mi][4 * ni + 2] + elt_[mi][4 * ni + 3]; + } + dst[mi] = tmp[0] + tmp[1]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 4; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + } + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_2x2() { + float dst[MMAS_M * 2]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + // Pair-wise adds in the different threads of the reference code (x+y and z+w). + float a_01 = elt_[mi][0] + elt_[mi][1]; + float a_45 = elt_[mi][4] + elt_[mi][5]; + + //// tmp[0/1] += __shfl_xor(2) in the reference code. + a_01 += elt_[mi][2] + elt_[mi][3]; + a_45 += elt_[mi][6] + elt_[mi][7]; + + //// tmp[0/1] += __shfl_xor(8) in the reference code. 
+ a_01 += a_45; + + if (MMAS_N >= 3) { + float a_89 = elt_[mi][8] + elt_[mi][9]; + a_89 += elt_[mi][10] + elt_[mi][11]; + if (MMAS_N == 4) { + float a_cd = elt_[mi][12] + elt_[mi][13]; + a_cd += elt_[mi][14] + elt_[mi][15]; + a_89 += a_cd; + } + a_01 += a_89; + } + dst[mi] = a_01; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 4; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + } + +// Store the different values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 4 == 0) { + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0]; + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 2 values (one for each warp). + float2 tmp = reinterpret_cast(smem_)[tidx_]; + + // Compute the reduction of those 2 values in a binary-tree fashion. + return Functor::apply(tmp.x, tmp.y); + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_1x4() { + float dst[MMAS_M * 2]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float tmp[2] = {0.f, 0.f}; +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + tmp[0] += elt_[mi][4 * ni + 0] + elt_[mi][4 * ni + 1]; + tmp[1] += elt_[mi][4 * ni + 2] + elt_[mi][4 * ni + 3]; + } + dst[mi] = tmp[0] + tmp[1]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 4; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + } + +// Store the different values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 4 == 0) { + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0]; + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the + // float4. + float4 tmp[1]; + if (tidx_ < Cta_tile::M) { + tmp[0] = reinterpret_cast(&smem_[0 * ELEMENTS / 2])[tidx_]; + } + + // Compute the reduction of those 8 values in a binary-tree fashion. + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y); + tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w); + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z); + + // Return the final reduction. + return tmp[0].x; + } + + // Do a CTA-wide reduction. 
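+  // Same scheme as reduce_1x4 but with 8 warps along N: each row's partials are read
+  // back as two float4 and combined in a binary tree.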
+ template + inline __device__ float reduce_1x8() { + float dst[MMAS_M * 2]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { + // Apply the summation inside the thread. + float tmp[MMAS_M * 2][2]; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + tmp[mi][0] = 0.f; + tmp[mi][1] = 0.f; +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + tmp[mi][0] += elt_[mi][4 * ni + 0]; + tmp[mi][0] += elt_[mi][4 * ni + 1]; + tmp[mi][1] += elt_[mi][4 * ni + 2]; + tmp[mi][1] += elt_[mi][4 * ni + 3]; + } + dst[mi] = tmp[mi][0] + tmp[mi][1]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 4; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + } + +// Store the different values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 4 == 0) { + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0]; + smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the + // float4. + float4 tmp[2]; + if (tidx_ < Cta_tile::M) { + tmp[0] = reinterpret_cast(&smem_[0 * ELEMENTS / 2])[tidx_]; + tmp[1] = reinterpret_cast(&smem_[1 * ELEMENTS / 2])[tidx_]; + } + + // Compute the reduction of those 8 values in a binary-tree fashion. + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y); + tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w); + tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y); + tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w); + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z); + tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z); + tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x); + + // Return the result. + return tmp[0].x; + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_() { + // The result of the reduction. Threads 0..Cta_tile::M-1 own a single row value. + float red = 0.f; + + // SEQLEN == 128. + if (Cta_tile::WARPS_M == 2 && Cta_tile::WARPS_N == 2) { + red = reduce_2x2(); + + // SEQLEN == 256. + } else if (Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4) { + red = reduce_1x4(); + + // SEQLEN == 384. + } else if (Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8) { + red = reduce_1x8(); + + // Not supported. + } else { + assert(false); + } + + return red; + } + + // Finalize the reduction. + inline __device__ void shuffle(float (&dst)[MMAS_M * 2], float red) { + // Store the value back to shared memory. + if (tidx_ < Cta_tile::M) { + smem_[tidx_] = red; + } + + // Make sure the data is in shared memory. + __syncthreads(); + +// Finally read the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0]; + dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + } + + // Do a CTA-wide reduction. 
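+  // Dispatch: with a single warp along N everything stays in registers and shuffles
+  // (reduce_Nx1, no __syncthreads); otherwise the per-warp partials go through shared
+  // memory (reduce_) and the row values are redistributed by shuffle().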
+ template + inline __device__ void reduce(float (&dst)[MMAS_M * 2]) { + // NOTE: 1 warp along reduce direction, no syncs + if (Cta_tile::WARPS_N == 1) { + reduce_Nx1(dst); + } else { + // The result of the reduction. Threads 0..Cta_tile::M-1 own a single row value. + float red = reduce_(); + + // Make sure we can write to shared memory. + __syncthreads(); + + // Finalize the reduction. + shuffle(dst, red); + } + } + + // Scale all the elements. + inline __device__ void scale(float const (&sum)[MMAS_M * 2]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. + float inv_sum[MMAS_M * 2]; +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + } + +// Update the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // Shared memory for the CTA-wide reduction. + float *smem_, *smem_write_, *smem_read_; + // The current thread index. + int tidx_; + // The elements. + float elt_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_hmma : public Softmax_base { + // The base class. + using Base = Softmax_base; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // Whether we need to skip the softmax due to the sliding-window attention + // Otherwise, we will get NANs as those tokens are all masked out. + enum { SLIDING_WINDOW_ATTENTION = Kernel_traits::SLIDING_WINDOW_ATTENTION }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax_hmma(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1), + params_softcapping_scale_bmm1_(params.softcapping_scale_bmm1) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // The elements. + float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3]; + float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3]; + + // Transform to accumulators. + acc[mi][ni].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + acc[mi][ni].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + acc[mi][ni].reg(2) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + acc[mi][ni].reg(3) = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Convert from FP16 fragments to floats. 
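A short aside before the unpack() that follows: scale() above divides by the row sum only when it is a finite non-zero value, and otherwise multiplies by 1, so fully-masked rows stay at zero instead of producing NaN/Inf. A minimal host-side sketch of that guard; the helper name is an assumption for illustration, not part of the kernel.

#include <cmath>
#include <cstdio>

void normalize_row(float* row, int n, float row_sum) {
  // Same guard as scale(): sum != sum catches NaN, sum == 0.f catches all-masked rows.
  float inv_sum = (row_sum == 0.f || row_sum != row_sum) ? 1.f : 1.f / row_sum;
  for (int i = 0; i < n; ++i) {
    row[i] *= inv_sum;
  }
}

int main() {
  float masked_row[4] = {0.f, 0.f, 0.f, 0.f};   // exp() of a fully-masked row
  normalize_row(masked_row, 4, 0.f);            // stays all-zero, no 0/0 NaN
  printf("%f %f %f %f\n", masked_row[0], masked_row[1], masked_row[2], masked_row[3]);
  return 0;
}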
+ inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Normalize the values, and clamp to finite half. + uint32_t acc_0 = satfinite_h2(hmul2(acc[mi][ni].reg(0), params_scale_bmm1_)); + uint32_t acc_1 = satfinite_h2(hmul2(acc[mi][ni].reg(1), params_scale_bmm1_)); + uint32_t acc_2 = satfinite_h2(hmul2(acc[mi][ni].reg(2), params_scale_bmm1_)); + uint32_t acc_3 = satfinite_h2(hmul2(acc[mi][ni].reg(3), params_scale_bmm1_)); + + // Extract the values as floats. + half2_to_float2(this->elt_[2 * mi + 0][4 * ni + 0], this->elt_[2 * mi + 0][4 * ni + 1], + acc_0); + half2_to_float2(this->elt_[2 * mi + 1][4 * ni + 0], this->elt_[2 * mi + 1][4 * ni + 1], + acc_1); + half2_to_float2(this->elt_[2 * mi + 0][4 * ni + 2], this->elt_[2 * mi + 0][4 * ni + 3], + acc_2); + half2_to_float2(this->elt_[2 * mi + 1][4 * ni + 2], this->elt_[2 * mi + 1][4 * ni + 3], + acc_3); + + // Attention logit softcapping scale. + // 1.0f / softcapping_scale has been fused to scale_bmm1. + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + this->elt_[2 * mi + 0][4 * ni + 0] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 0]); + this->elt_[2 * mi + 0][4 * ni + 1] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 1]); + this->elt_[2 * mi + 1][4 * ni + 0] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 0]); + this->elt_[2 * mi + 1][4 * ni + 1] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 1]); + this->elt_[2 * mi + 0][4 * ni + 2] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 2]); + this->elt_[2 * mi + 0][4 * ni + 3] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 3]); + this->elt_[2 * mi + 1][4 * ni + 2] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 2]); + this->elt_[2 * mi + 1][4 * ni + 3] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 3]); + } + } + } + } + + // Apply the exp to all the elements. + // Need to make sure the results are zero when all elts are -FLT_MAX + // as it is possible that all tokens are masked out. + template + inline __device__ void apply_exp_with_mask(float const (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float max_val = APPLY_MASK && max[mi] == -FLT_MAX ? 0.f : max[mi]; +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + this->elt_[mi][ni] = expf(this->elt_[mi][ni] - max_val); + } + } + } + + // The scaling factor. + uint32_t const params_scale_bmm1_; + float const params_softcapping_scale_bmm1_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Fragment_helper {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_helper { + // The traits. + using Traits = fmha::Volta_imma_int8_int32_traits; + // The fragment A. + using Fragment_a = fmha::Fragment_a; + // The accumulator. + using Accumulator = fmha::Fragment_accumulator; + + // Load a 2x4 array from registers. 
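An aside before the IMMA fragment helpers: apply_exp_with_mask() above substitutes 0.f for a row max of -FLT_MAX. Subtracting the real max from a fully-masked row would give exp(0) == 1 for every element, whereas subtracting 0.f keeps every element at exp(-FLT_MAX) == 0, so the row contributes nothing. A small host-side sketch of the same guard (names are illustrative only, not from the kernel).

#include <cfloat>
#include <cmath>
#include <cstdio>

void exp_with_mask(float* row, int n, float row_max) {
  float max_val = (row_max == -FLT_MAX) ? 0.f : row_max;   // the APPLY_MASK guard
  for (int i = 0; i < n; ++i) {
    row[i] = expf(row[i] - max_val);
  }
}

int main() {
  float row[4] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX};  // all tokens masked out
  exp_with_mask(row, 4, -FLT_MAX);
  printf("%f %f %f %f\n", row[0], row[1], row[2], row[3]);  // prints zeros, not ones
  return 0;
}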
+ static inline __device__ void load(int32_t (&dst)[2][4], Accumulator const& src) { + dst[0][0] = src.elt(0); + dst[0][1] = src.elt(1); + dst[0][2] = src.elt(2); + dst[0][3] = src.elt(3); + dst[1][0] = src.elt(4); + dst[1][1] = src.elt(5); + dst[1][2] = src.elt(6); + dst[1][3] = src.elt(7); + } + + // Store to an accumulator. + static inline __device__ void store(Accumulator& dst, uint32_t const (&src)[2][4]) { + dst.reg(0) = src[0][0]; + dst.reg(1) = src[0][1]; + dst.reg(2) = src[0][2]; + dst.reg(3) = src[0][3]; + dst.reg(4) = src[1][0]; + dst.reg(5) = src[1][1]; + dst.reg(6) = src[1][2]; + dst.reg(7) = src[1][3]; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_helper { + // The traits. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The fragment A. + using Fragment_a = fmha::Fragment_a; + // The accumulator. + using Accumulator = fmha::Fragment_accumulator; + + // Load a 2x4 array from registers. + static inline __device__ void load(int32_t (&dst)[2][4], Accumulator const& src) { + dst[0][0] = src.elt(0); + dst[0][1] = src.elt(1); + dst[0][2] = src.elt(2); + dst[0][3] = src.elt(3); + dst[1][0] = src.elt(4); + dst[1][1] = src.elt(5); + dst[1][2] = src.elt(6); + dst[1][3] = src.elt(7); + } + + // Store to an accumulator. + static inline __device__ void store(Accumulator& dst, uint32_t const (&src)[2][4]) { + dst.reg(0) = src[0][0]; + dst.reg(1) = src[0][1]; + dst.reg(2) = src[0][2]; + dst.reg(3) = src[0][3]; + dst.reg(4) = src[1][0]; + dst.reg(5) = src[1][1]; + dst.reg(6) = src[1][2]; + dst.reg(7) = src[1][3]; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Fragment_helper { + // The traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The fragment A. + using Fragment_a = fmha::Fragment_a; + // The accumulator. + using Accumulator = fmha::Fragment_accumulator; + + // Load a 2x4 array from registers. + static inline __device__ void load(int32_t (&dst)[2][4], Accumulator const& src) { + dst[0][0] = src.elt(0); + dst[0][1] = src.elt(1); + dst[0][2] = src.elt(4); + dst[0][3] = src.elt(5); + dst[1][0] = src.elt(2); + dst[1][1] = src.elt(3); + dst[1][2] = src.elt(6); + dst[1][3] = src.elt(7); + } + + // Store to an accumulator. + static inline __device__ void store(Accumulator& dst, uint32_t const (&src)[2][4]) { + dst.reg(0) = src[0][0]; + dst.reg(1) = src[0][1]; + dst.reg(4) = src[0][2]; + dst.reg(5) = src[0][3]; + dst.reg(2) = src[1][0]; + dst.reg(3) = src[1][1]; + dst.reg(6) = src[1][2]; + dst.reg(7) = src[1][3]; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_imma : public Softmax_base { + // The base class. + using Base = Softmax_base; + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // The dst type + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax_imma(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1), + params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. 
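The IMMA store() below converts scaled FP32 probabilities to int8 with `cvt.rni.sat.s8.f32`, i.e. round to nearest-even and saturate to [-128, 127]. Here is a host-side reference of that conversion under stated assumptions: the helper name and the example scale are illustrative, not the kernel's code path.

#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t quantize_int8(float x, float scale) {
  float r = rintf(x * scale);          // .rni: round to nearest, ties to even
  if (r > 127.f) r = 127.f;            // .sat: clamp to the s8 range
  if (r < -128.f) r = -128.f;
  return static_cast<int8_t>(r);
}

int main() {
  printf("%d %d %d\n",
         quantize_int8(0.51f, 127.f),    // 65
         quantize_int8(1.5f, 127.f),     // saturates to 127
         quantize_int8(-2.0f, 127.f));   // saturates to -128
  return 0;
}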
+ template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Scale the FP32 elements. + uint32_t tmp[2][4]; +#pragma unroll + for (int mj = 0; mj < 2; ++mj) { +#pragma unroll + for (int nj = 0; nj < 4; ++nj) { + float f = this->elt_[2 * mi + mj][4 * ni + nj] * scale; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(tmp[mj][nj]) : "f"(f)); + } + } + + // Convert to int8 and store. + Fragment_helper::store(acc[mi][ni], tmp); + } + } + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Convert from accumulators to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scale = reinterpret_cast(params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Load the values from the accumulator's registers. + int32_t tmp[2][4]; + Fragment_helper::load(tmp, acc[mi][ni]); + +// Convert to FP32 and scale. +#pragma unroll + for (int mj = 0; mj < 2; ++mj) { +#pragma unroll + for (int nj = 0; nj < 4; ++nj) { +#if defined(USE_I2F_EMULATION_TRICK) + float f = reinterpret_cast(tmp[mj][nj]); + this->elt_[2 * mi + mj][4 * ni + nj] = (f - FP32_I2F_MAGIC_NUMBER) * scale; +#else + this->elt_[2 * mi + mj][4 * ni + nj] = static_cast(tmp[mj][nj]) * scale; +#endif // defined(USE_I2F_EMULATION_TRICK) + } + } + } + } + } + + // Repack. We could use store/load to match the Smem_tile API. (shared by Ampere IMMA and Ada + // QMMA) + template + inline __device__ void pack(Fragment_a_ (&dst)[K][M]) { + // We pack N 16x16 acc tiles into K 16x32 tiles for A. + // In the 16x16 tile, a thread owns 4 elts per row (4 regs). + // In the 16x32 A tile, a thread owns 8 elts per row (2 regs). + // Hence we have to pack with a 2:1 ratio. + // For N = 1, K is 1: pack 4 values into dst reg 0. Set reg 1 to 0. + // For N = 2, K is 1: pack 8 values into dst regs 0, 1. + // For N = 3, K is 2: pack 12 values into dst regs (0,0), (0,1), (1,0). Set (1,1) to 0. + // For N = 4, K is 2: pack 16 values into dst regs (0,0), (0,1), (1,0), (1,1) + // For N = 5, K is 3: pack 20 values into dst regs (0,0), (0,1), (1,0), (1,1), (2,0). Set (2,1) + // to 0. For N = 6, K is 3: pack 24 values into dst regs (0,0), (0,1), (1,0), (1,1), (2,0), + // (2,1) + + static_assert(K == 3 || K == 2 || K == 1, ""); + + float const scale = reinterpret_cast(this->params_scale_softmax_); + +#pragma unroll + for (int mi = 0; mi < M; ++mi) { + // 1st row - 12 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][0] * scale; + float tmp_01 = this->elt_[2 * mi + 0][1] * scale; + float tmp_02 = this->elt_[2 * mi + 0][2] * scale; + float tmp_03 = this->elt_[2 * mi + 0][3] * scale; + float tmp_04 = this->elt_[2 * mi + 0][4] * scale; + float tmp_05 = this->elt_[2 * mi + 0][5] * scale; + float tmp_06 = this->elt_[2 * mi + 0][6] * scale; + float tmp_07 = this->elt_[2 * mi + 0][7] * scale; + float tmp_08 = this->elt_[2 * mi + 0][8] * scale; + float tmp_09 = this->elt_[2 * mi + 0][9] * scale; + float tmp_0a = this->elt_[2 * mi + 0][10] * scale; + float tmp_0b = this->elt_[2 * mi + 0][11] * scale; + + // 2nd row - 12 elements per row. 
+ float tmp_20 = this->elt_[2 * mi + 1][0] * scale; + float tmp_21 = this->elt_[2 * mi + 1][1] * scale; + float tmp_22 = this->elt_[2 * mi + 1][2] * scale; + float tmp_23 = this->elt_[2 * mi + 1][3] * scale; + float tmp_24 = this->elt_[2 * mi + 1][4] * scale; + float tmp_25 = this->elt_[2 * mi + 1][5] * scale; + float tmp_26 = this->elt_[2 * mi + 1][6] * scale; + float tmp_27 = this->elt_[2 * mi + 1][7] * scale; + float tmp_28 = this->elt_[2 * mi + 1][8] * scale; + float tmp_29 = this->elt_[2 * mi + 1][9] * scale; + float tmp_2a = this->elt_[2 * mi + 1][10] * scale; + float tmp_2b = this->elt_[2 * mi + 1][11] * scale; + + // Pack the first 12 elements to 6 registers of 2 fragments. + dst[0][mi].reg(0) = fmha::float4_to_8bitx4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[0][mi].reg(1) = fmha::float4_to_8bitx4(tmp_20, tmp_21, tmp_22, tmp_23); + dst[0][mi].reg(2) = fmha::float4_to_8bitx4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[0][mi].reg(3) = fmha::float4_to_8bitx4(tmp_24, tmp_25, tmp_26, tmp_27); + if (K > 1) { + dst[1][mi].reg(0) = fmha::float4_to_8bitx4(tmp_08, tmp_09, tmp_0a, tmp_0b); + dst[1][mi].reg(1) = fmha::float4_to_8bitx4(tmp_28, tmp_29, tmp_2a, tmp_2b); + } + + if (Mma_tile::MMAS_N == 6) { + float tmp_0c = this->elt_[2 * mi + 0][12] * scale; + float tmp_0d = this->elt_[2 * mi + 0][13] * scale; + float tmp_0e = this->elt_[2 * mi + 0][14] * scale; + float tmp_0f = this->elt_[2 * mi + 0][15] * scale; + float tmp_10 = this->elt_[2 * mi + 0][16] * scale; + float tmp_11 = this->elt_[2 * mi + 0][17] * scale; + float tmp_12 = this->elt_[2 * mi + 0][18] * scale; + float tmp_13 = this->elt_[2 * mi + 0][19] * scale; + float tmp_14 = this->elt_[2 * mi + 0][20] * scale; + float tmp_15 = this->elt_[2 * mi + 0][21] * scale; + float tmp_16 = this->elt_[2 * mi + 0][22] * scale; + float tmp_17 = this->elt_[2 * mi + 0][23] * scale; + + float tmp_2c = this->elt_[2 * mi + 1][12] * scale; + float tmp_2d = this->elt_[2 * mi + 1][13] * scale; + float tmp_2e = this->elt_[2 * mi + 1][14] * scale; + float tmp_2f = this->elt_[2 * mi + 1][15] * scale; + float tmp_30 = this->elt_[2 * mi + 1][16] * scale; + float tmp_31 = this->elt_[2 * mi + 1][17] * scale; + float tmp_32 = this->elt_[2 * mi + 1][18] * scale; + float tmp_33 = this->elt_[2 * mi + 1][19] * scale; + float tmp_34 = this->elt_[2 * mi + 1][20] * scale; + float tmp_35 = this->elt_[2 * mi + 1][21] * scale; + float tmp_36 = this->elt_[2 * mi + 1][22] * scale; + float tmp_37 = this->elt_[2 * mi + 1][23] * scale; + + dst[1][mi].reg(2) = fmha::float4_to_8bitx4(tmp_0c, tmp_0d, tmp_0e, tmp_0f); + dst[1][mi].reg(3) = fmha::float4_to_8bitx4(tmp_2c, tmp_2d, tmp_2e, tmp_2f); + dst[2][mi].reg(0) = fmha::float4_to_8bitx4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[2][mi].reg(1) = fmha::float4_to_8bitx4(tmp_30, tmp_31, tmp_32, tmp_33); + dst[2][mi].reg(2) = fmha::float4_to_8bitx4(tmp_14, tmp_15, tmp_16, tmp_17); + dst[2][mi].reg(3) = fmha::float4_to_8bitx4(tmp_34, tmp_35, tmp_36, tmp_37); + } else if (Mma_tile::MMAS_N == 4) { + // SEQLEN == 128. 
+ float tmp_0c = this->elt_[2 * mi + 0][12] * scale; + float tmp_0d = this->elt_[2 * mi + 0][13] * scale; + float tmp_0e = this->elt_[2 * mi + 0][14] * scale; + float tmp_0f = this->elt_[2 * mi + 0][15] * scale; + + float tmp_1c = this->elt_[2 * mi + 1][12] * scale; + float tmp_1d = this->elt_[2 * mi + 1][13] * scale; + float tmp_1e = this->elt_[2 * mi + 1][14] * scale; + float tmp_1f = this->elt_[2 * mi + 1][15] * scale; + + dst[1][mi].reg(2) = fmha::float4_to_8bitx4(tmp_0c, tmp_0d, tmp_0e, tmp_0f); + dst[1][mi].reg(3) = fmha::float4_to_8bitx4(tmp_1c, tmp_1d, tmp_1e, tmp_1f); + + // SEQLEN == 384 or SEQLEN == 256. + } else if (Mma_tile::MMAS_N == 3 || Mma_tile::MMAS_N == 2) { + // TODO added second OR term for ampere imma s=256: correct? + dst[1][mi].reg(2) = 0u; + dst[1][mi].reg(3) = 0u; + } else if (Mma_tile::MMAS_N == 1) { + dst[0][mi].reg(2) = 0u; + dst[0][mi].reg(3) = 0u; + + // Not implemented. + } else { + assert(false); + } + } + } + + // The scaling factors. + uint32_t const params_scale_bmm1_, params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_qmma : public Softmax_imma {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_qmma + : public Softmax_imma { + // The Traits + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Softmax_imma; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // Ctor. + template + inline __device__ Softmax_qmma(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1), + params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // scale + acc[mi][ni].ele(0) = this->elt_[2 * mi + 0][4 * ni + 0] * scale; + acc[mi][ni].ele(1) = this->elt_[2 * mi + 0][4 * ni + 1] * scale; + acc[mi][ni].ele(4) = this->elt_[2 * mi + 0][4 * ni + 2] * scale; + acc[mi][ni].ele(5) = this->elt_[2 * mi + 0][4 * ni + 3] * scale; + acc[mi][ni].ele(2) = this->elt_[2 * mi + 1][4 * ni + 0] * scale; + acc[mi][ni].ele(3) = this->elt_[2 * mi + 1][4 * ni + 1] * scale; + acc[mi][ni].ele(6) = this->elt_[2 * mi + 1][4 * ni + 2] * scale; + acc[mi][ni].ele(7) = this->elt_[2 * mi + 1][4 * ni + 3] * scale; + } + } + + // Delegate to the gmem tile to store. + // TODO: need fp32 to fp8 conversion (move this to gmem_tile) + gmem_tile.store(acc); + } + + // Convert from accumulators to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scale = reinterpret_cast(params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Convert to FP32 and scale. 
+ this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scale; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scale; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scale; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scale; + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scale; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scale; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scale; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scale; + } + } + } + + template + inline __device__ void apply_exp_with_mask(float const (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float max_val = APPLY_MASK && max[mi] == -FLT_MAX + ? 0.f + : (max[mi] - logf(Traits::SOFTMAX_FP_QUANT_SCALE)); +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + this->elt_[mi][ni] = expf(this->elt_[mi][ni] - max_val); + } + } + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + float const scale = reinterpret_cast(this->params_scale_softmax_); + +// The canonical layout in K should be R0: [0,1,2,3] R2: [16,17,18,19] +// Note below that this is not possible with the register layout of the accumulator. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 8 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][8 * ki + 0] * scale; // + 0 + float tmp_01 = this->elt_[2 * mi + 0][8 * ki + 1] * scale; // + 1 + float tmp_02 = this->elt_[2 * mi + 0][8 * ki + 2] * scale; // + 8 + float tmp_03 = this->elt_[2 * mi + 0][8 * ki + 3] * scale; // + 9 + float tmp_04 = this->elt_[2 * mi + 0][8 * ki + 4] * scale; // +16 + float tmp_05 = this->elt_[2 * mi + 0][8 * ki + 5] * scale; // +17 + float tmp_06 = this->elt_[2 * mi + 0][8 * ki + 6] * scale; // +24 + float tmp_07 = this->elt_[2 * mi + 0][8 * ki + 7] * scale; // +25 + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][8 * ki + 0] * scale; // + 0 + float tmp_11 = this->elt_[2 * mi + 1][8 * ki + 1] * scale; // + 1 + float tmp_12 = this->elt_[2 * mi + 1][8 * ki + 2] * scale; // + 8 + float tmp_13 = this->elt_[2 * mi + 1][8 * ki + 3] * scale; // + 9 + float tmp_14 = this->elt_[2 * mi + 1][8 * ki + 4] * scale; // +16 + float tmp_15 = this->elt_[2 * mi + 1][8 * ki + 5] * scale; // +17 + float tmp_16 = this->elt_[2 * mi + 1][8 * ki + 6] * scale; // +24 + float tmp_17 = this->elt_[2 * mi + 1][8 * ki + 7] * scale; // +25 + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float4_to_fp8x4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = fmha::float4_to_fp8x4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[ki][mi].reg(2) = fmha::float4_to_fp8x4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[ki][mi].reg(3) = fmha::float4_to_fp8x4(tmp_14, tmp_15, tmp_16, tmp_17); + } + } + } + + // The scaling factors. + uint32_t const params_scale_bmm1_, params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_qmma + : public Softmax_imma { + // The Traits + using Traits = fmha::Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Softmax_imma; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // Ctor. 
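Before the fp16-accumulation variant below, a note on apply_exp_with_mask() above: subtracting logf(SOFTMAX_FP_QUANT_SCALE) from the row max rescales every exp() output by the FP8 quantization scale without an extra multiply, because exp(x - (m - log S)) == S * exp(x - m). A tiny host-side check of that identity; S is only a stand-in float here, not the kernel's constant.

#include <cmath>
#include <cstdio>

int main() {
  float S = 256.f;                  // stand-in for the FP8 softmax quantization scale
  float x = 1.25f, m = 3.5f;
  float fused = expf(x - (m - logf(S)));
  float reference = S * expf(x - m);
  printf("fused=%f reference=%f\n", fused, reference);   // the two values match
  return 0;
}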
+ template + inline __device__ Softmax_qmma(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1), + params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // scale + acc[mi][ni].ele(0) = this->elt_[2 * mi + 0][4 * ni + 0] * scale; + acc[mi][ni].ele(1) = this->elt_[2 * mi + 0][4 * ni + 1] * scale; + acc[mi][ni].ele(4) = this->elt_[2 * mi + 0][4 * ni + 2] * scale; + acc[mi][ni].ele(5) = this->elt_[2 * mi + 0][4 * ni + 3] * scale; + acc[mi][ni].ele(2) = this->elt_[2 * mi + 1][4 * ni + 0] * scale; + acc[mi][ni].ele(3) = this->elt_[2 * mi + 1][4 * ni + 1] * scale; + acc[mi][ni].ele(6) = this->elt_[2 * mi + 1][4 * ni + 2] * scale; + acc[mi][ni].ele(7) = this->elt_[2 * mi + 1][4 * ni + 3] * scale; + } + } + + // Delegate to the gmem tile to store. + // TODO: need fp32 to fp8 conversion (move this to gmem_tile) + gmem_tile.store(acc); + } + + // Convert from accumulators to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Convert to FP32 and scale. + float2* elt_ptr0 = reinterpret_cast(this->elt_[2 * mi + 0] + 4 * ni); + float2* elt_ptr1 = reinterpret_cast(this->elt_[2 * mi + 1] + 4 * ni); + elt_ptr0[0] = fmha::half2_to_float2(fmha::hmul2(acc[mi][ni].reg(0), params_scale_bmm1_)); + elt_ptr0[1] = fmha::half2_to_float2(fmha::hmul2(acc[mi][ni].reg(2), params_scale_bmm1_)); + elt_ptr1[0] = fmha::half2_to_float2(fmha::hmul2(acc[mi][ni].reg(1), params_scale_bmm1_)); + elt_ptr1[1] = fmha::half2_to_float2(fmha::hmul2(acc[mi][ni].reg(3), params_scale_bmm1_)); + } + } + } + + // The scaling factors. + uint32_t const params_scale_bmm1_, params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + // The traits class. + using Traits = fmha::Volta_hmma_fp16_traits; + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // The number of groups of warp such that we have at most 2 warps writing consecutive elements. + enum { GROUPS = fmha::Div_up::VALUE }; + + // The number of elements that we are going to store per row. + enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS }; + + // The number of rows. + enum { ROWS = Cta_tile::M * GROUPS }; + + // The total number of elements. + enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // If shared memory is used + enum { USE_SHARED_MEMORY = Cta_tile::WARPS_N > 1 }; + + // The number of rows per thread. 
+ enum { ROWS_PER_THREAD = MMAS_M }; + + // DEBUG. + static_assert(ELEMENTS == Cta_tile::M * Cta_tile::WARPS_N, ""); + + // END OF DEBUG. + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : params_scale_bmm1_(params.scale_bmm1), + params_softcapping_scale_bmm1_(params.softcapping_scale_bmm1), + smem_(reinterpret_cast(smem)), + tidx_(tidx) { + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Decompose the warp index into M and N. + int warp_m = warp % Cta_tile::WARPS_M; + int warp_n = warp / Cta_tile::WARPS_M; + + // Decompose the warp-n index into group/position-inside-the-group. + int warp_g = warp_n / ELEMENTS_PER_ROW; + int warp_i = warp_n % ELEMENTS_PER_ROW; + + // The row written/read by the thread (threads i and i+8 are on the same row). + int row = (lane & 0x10) / 2 + (lane & 0x07); + + // The location written by the threads. + int write_row = warp_g * Cta_tile::M + warp_m * Mma_tile::M_PER_MMA + row; + int write_col = warp_i; + + // Assemble the write pointer. + smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; + // Assemble the read pointer. + smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + row]; + } + + // Apply mask before softmax. Use 1 byte per MMA distributed as 1x8. + template + inline __device__ void apply_mask(Mask const& mask) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < 8; ++ii) { + if (!mask.is_valid(mi, ni, 0, ii)) { + elt_[mi][8 * ni + ii] = -FLT_MAX; + } + } + } + } + } + + template + inline __device__ void apply_mask_alibi(Mask const& mask, int head_id, + AlibiParams const& alibi_params) { + // 'if constexpr' because ALiBi is only defined for causal masks + if constexpr (Kernel_traits::CAUSAL_MASK) { + float m = get_alibi_head_scaling_factor(head_id, alibi_params); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < 8; ++ii) { + int row, col; + mask.get_row_col(row, col, mi, ni, 0, ii); + if (mask.is_valid(row, col)) { + // Since softmax is shift invariant, + // it is sufficient just to use the column as the multiplier + elt_[mi][8 * ni + ii] = elt_[mi][8 * ni + ii] * alibi_params.scale_after_alibi + + m * (col + alibi_params.sequence_pos_offset); + } else { + elt_[mi][8 * ni + ii] = -FLT_MAX; + } + } + } + } + } else { + __builtin_unreachable(); + } + } + + // Apply the mask to unpacked data. + inline __device__ void apply_mask(uint32_t const (&packed_mask)[MMAS_M]) { + // This code works only if we have MMAS_N <= 4. + static_assert(MMAS_N <= 4, ""); + + // Expand the mask. + int mask[MMAS_M][MMAS_N * 8]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < MMAS_N * 8; ++ii) { + mask[mi][ii] = packed_mask[mi] & (1u << ii); + } + } + +// Apply the mask. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + if (!mask[mi][ni]) { + elt_[mi][ni] = -FLT_MAX; + } + } + } + } + + // Mask the elements that are outside the the sequence length. + inline __device__ void apply_mask(int const seqlen) { + // The warp/lane decomposition. + int const warp = threadIdx.x / Cta_tile::THREADS_PER_WARP; + int const lane = threadIdx.x % Cta_tile::THREADS_PER_WARP; + + // The warp in the n dimension. 
+ int const warp_n = warp / Cta_tile::WARPS_M; + // The base position within a quad. + int const offset = warp_n * 16 + (threadIdx.x & 0x08) / 2; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // The position in the sequence. + int pos = offset + ni * Mma_tile::N_PER_MMA_PER_CTA; + + // Determine the position in the sequence. + if (pos + 0 >= seqlen) { + elt_[mi][8 * ni + 0] = -FLT_MAX; + } + if (pos + 1 >= seqlen) { + elt_[mi][8 * ni + 1] = -FLT_MAX; + } + if (pos + 2 >= seqlen) { + elt_[mi][8 * ni + 2] = -FLT_MAX; + } + if (pos + 3 >= seqlen) { + elt_[mi][8 * ni + 3] = -FLT_MAX; + } + if (pos + 8 >= seqlen) { + elt_[mi][8 * ni + 4] = -FLT_MAX; + } + if (pos + 9 >= seqlen) { + elt_[mi][8 * ni + 5] = -FLT_MAX; + } + if (pos + 10 >= seqlen) { + elt_[mi][8 * ni + 6] = -FLT_MAX; + } + if (pos + 11 >= seqlen) { + elt_[mi][8 * ni + 7] = -FLT_MAX; + } + } + } + } + + // Apply the exp to all the elements. + // Need to make sure the results are zero when all elts are -FLT_MAX + // as it is possible that all tokens are masked out. + template + inline __device__ void apply_exp_with_mask(float const (&max)[MMAS_M]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + float max_val = APPLY_MASK && max[mi] == -FLT_MAX ? 0.f : max[mi]; +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + this->elt_[mi][ni] = expf(this->elt_[mi][ni] - max_val); + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const max) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max); + } + } + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const (&max)[MMAS_M]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); + } + } + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(MMAS_M == M && MMAS_N == K, ""); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 8 elements per row. + float tmp_0 = this->elt_[mi][8 * ki + 0]; + float tmp_1 = this->elt_[mi][8 * ki + 1]; + float tmp_2 = this->elt_[mi][8 * ki + 2]; + float tmp_3 = this->elt_[mi][8 * ki + 3]; + float tmp_4 = this->elt_[mi][8 * ki + 4]; + float tmp_5 = this->elt_[mi][8 * ki + 5]; + float tmp_6 = this->elt_[mi][8 * ki + 6]; + float tmp_7 = this->elt_[mi][8 * ki + 7]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_0, tmp_1); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_2, tmp_3); + dst[ki][mi].reg(2) = fmha::float2_to_16bit_2(tmp_4, tmp_5); + dst[ki][mi].reg(3) = fmha::float2_to_16bit_2(tmp_6, tmp_7); + } + } + } + + // Do a CTA-wide reduction. + template + inline __device__ void reduce_Nx1(float (&dst)[MMAS_M]) { +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The thread local math in the reference code. 
+ float sums[MMAS_N * 2]; +#pragma unroll + for (int ii = 0; ii < MMAS_N * 2; ++ii) { + sums[ii] = elt_[mi][4 * ii + 0]; + sums[ii] += elt_[mi][4 * ii + 1]; + sums[ii] += elt_[mi][4 * ii + 2]; + sums[ii] += elt_[mi][4 * ii + 3]; + } + +// Columns 0 and 8: __shfl( 2). +#pragma unroll + for (int ii = 0; ii < MMAS_N; ++ii) { + sums[2 * ii] += sums[2 * ii + 1]; + } + +// Columns 0 and 32: __shfl( 8). +#pragma unroll + for (int ii = 0; ii < MMAS_N / 2; ++ii) { // MMAS_N / 2 == 0 if MMAS_N <= 1. + sums[4 * ii] += sums[4 * ii + 2]; + } + + // Columns 0 and 64: __shfl(16). + if (MMAS_N == 3) { + sums[0] += sums[4]; + } else if (MMAS_N >= 4) { +#pragma unroll + for (int ii = 0; ii < MMAS_N / 4; ++ii) { // MMAS_N / 4 == 0 if MMAS_N <= 2. + sums[8 * ii] += sums[8 * ii + 4]; + } + } + + // Store the final value for that row. + dst[mi] = sums[0]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 8; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 8)); + } + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_2x2() { + float dst[MMAS_M]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The thread local math in the reference code. + float sums[MMAS_N * 2]; +#pragma unroll + for (int ii = 0; ii < MMAS_N * 2; ++ii) { + sums[ii] = elt_[mi][4 * ii + 0]; + sums[ii] += elt_[mi][4 * ii + 1]; + sums[ii] += elt_[mi][4 * ii + 2]; + sums[ii] += elt_[mi][4 * ii + 3]; + } + +// Columns 0 and 8: __shfl( 2). +#pragma unroll + for (int ii = 0; ii < MMAS_N; ++ii) { + sums[2 * ii] += sums[2 * ii + 1]; + } + +// Columns 0 and 32: __shfl( 8). +#pragma unroll + for (int ii = 0; ii < MMAS_N / 2; ++ii) { // MMAS_N / 2 == 0 if MMAS_N <= 1. + sums[4 * ii] += sums[4 * ii + 2]; + } + + // Columns 0 and 64: __shfl(16). + if (MMAS_N == 3) { + sums[0] += sums[4]; + } else if (MMAS_N >= 4) { +#pragma unroll + for (int ii = 0; ii < MMAS_N / 4; ++ii) { // MMAS_N / 4 == 0 if MMAS_N <= 2. + sums[8 * ii] += sums[8 * ii + 4]; + } + } + + // Store the final value for that row. + dst[mi] = sums[0]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 8; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 8)); + } + +// Store the different values to shared memory. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 16 < 8) { + smem_write_[mi * Mma_tile::M_PER_MMA_PER_CTA * ELEMENTS_PER_ROW] = dst[mi]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 2 values (one for each warp). + float2 tmp = reinterpret_cast(smem_)[tidx_]; + + // Compute the reduction of those 2 values in a binary-tree fashion. 
+ return Functor::apply(tmp.x, tmp.y); + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_1x4() { + float dst[MMAS_M]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The thread local math in the reference code. + float sums[MMAS_N * 2]; +#pragma unroll + for (int ii = 0; ii < MMAS_N * 2; ++ii) { + sums[ii] = elt_[mi][4 * ii + 0]; + sums[ii] += elt_[mi][4 * ii + 1]; + sums[ii] += elt_[mi][4 * ii + 2]; + sums[ii] += elt_[mi][4 * ii + 3]; + } + + // Columns 0 and 128 (the ref code uses a step of 128). Not needed if SEQLEN <= 128. + if (Cta_tile::N > 128) { +#pragma unroll + for (int ii = 0; ii < MMAS_N; ++ii) { + sums[ii] += sums[MMAS_N + ii]; + } + } + +// Columns 0 and 8: __shfl( 2). +#pragma unroll + for (int ii = 0; ii < MMAS_N; ++ii) { + sums[2 * ii] += sums[2 * ii + 1]; + } + +// Columns 0 and 64: __shfl(16). +#pragma unroll + for (int ii = 0; ii < MMAS_N / 2; ++ii) { // MMAS_N / 2 == 0 if MMAS_N <= 1. + sums[4 * ii] += sums[4 * ii + 2]; + } + + // Store the final value for that row. + dst[mi] = sums[0]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 8; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 8)); + } + +// Store the different values to shared memory. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 16 < 8) { + smem_write_[mi * Mma_tile::M_PER_MMA_PER_CTA * ELEMENTS_PER_ROW] = dst[mi]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 4 values (one for each warp). + float2 tmp[2]; + if (tidx_ < Cta_tile::M) { + tmp[0] = reinterpret_cast(&smem_[0 * ELEMENTS / 2])[tidx_]; + tmp[1] = reinterpret_cast(&smem_[1 * ELEMENTS / 2])[tidx_]; + } + + // Compute the reduction of those 4 values in a binary-tree fashion. + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y); + tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y); + tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x); + + // Return the final reduction. + return tmp[0].x; + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_1x8() { + float dst[MMAS_M]; +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + if (Functor::IS_SUM) { +// Apply the summation inside the thread for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // The thread local math in the reference code. + float sums[MMAS_N * 2]; +#pragma unroll + for (int ii = 0; ii < MMAS_N * 2; ++ii) { + sums[ii] = elt_[mi][4 * ii + 0]; + sums[ii] += elt_[mi][4 * ii + 1]; + sums[ii] += elt_[mi][4 * ii + 2]; + sums[ii] += elt_[mi][4 * ii + 3]; + } + +// Columns 0 and 128 (the ref code uses a step of 128). Not needed if SEQLEN <= 128. +#pragma unroll + for (int ii = 1; ii < MMAS_N; ++ii) { + sums[0] += sums[2 * ii + 0]; + sums[1] += sums[2 * ii + 1]; + } + + // Columns 0 and 8: __shfl( 2). + dst[mi] = sums[0] + sums[1]; + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// Apply the functor for each row inside a thread. 
+#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * 8; ++ni) { + dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]); + } + } + } + +// Apply the functor for each row. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 8)); + } + +// Store the different values to shared memory. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + if (tidx_ % 16 < 8) { + smem_write_[mi * Mma_tile::M_PER_MMA_PER_CTA * ELEMENTS_PER_ROW] = dst[mi]; + } + } + + // Make sure the values are in shared memory. + __syncthreads(); + + // Load 8 values (one for each warp). + float2 tmp[4]; + if (tidx_ < Cta_tile::M) { + tmp[0] = reinterpret_cast(&smem_[0 * ELEMENTS / 4])[tidx_]; + tmp[1] = reinterpret_cast(&smem_[1 * ELEMENTS / 4])[tidx_]; + tmp[2] = reinterpret_cast(&smem_[2 * ELEMENTS / 4])[tidx_]; + tmp[3] = reinterpret_cast(&smem_[3 * ELEMENTS / 4])[tidx_]; + } + + // // DEBUG. + // if( tidx_ == 0 ) { + // #pragma unroll + // for( int ii = 0; ii < 4; ++ii ) { + // printf("tidx=%3d tmp[%d]=%8.3f %8.3f\n", tidx_, ii, tmp[ii].x, tmp[ii].y); + // } + // } + // // END OF DEBUG. + + // Compute the reduction of those 8 values in a binary-tree fashion. + tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y); + tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y); + tmp[2].x = Functor::apply(tmp[2].x, tmp[2].y); + tmp[3].x = Functor::apply(tmp[3].x, tmp[3].y); + + tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x); + tmp[2].x = Functor::apply(tmp[2].x, tmp[3].x); + + tmp[0].x = Functor::apply(tmp[0].x, tmp[2].x); + + // Return the final reduction. + return tmp[0].x; + } + + // Do a CTA-wide reduction. + template + inline __device__ float reduce_() { + // The final reduction. + float red = 0.f; + + // SEQLEN == 128. + if (Cta_tile::WARPS_M == 2 && Cta_tile::WARPS_N == 2) { + red = reduce_2x2(); + + // SEQLEN == 256. + } else if (Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4) { + red = reduce_1x4(); + + // SEQLEN == 256. + } else if (Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8) { + red = reduce_1x8(); + + // Not supported. + } else { + assert(false); + } + + return red; + } + + // Finalize the reduction. + inline __device__ void shuffle(float (&dst)[MMAS_M], float red) { + // Store the value back to shared memory. + if (tidx_ < Cta_tile::M) { + smem_[tidx_] = red; + } + + // Make sure the data is in shared memory. + __syncthreads(); + +// Finally read the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + dst[mi] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA]; + } + + // Make sure we are done reading shared memory. + __syncthreads(); + } + + // Do a CTA-wide reduction. + template + inline __device__ void reduce(float (&dst)[MMAS_M]) { + // NOTE: 1 warp along reduce direction, no syncs + if (Cta_tile::WARPS_N == 1) { + reduce_Nx1(dst); + } else { + // The result of the reduction. Threads 0..Cta_tile::M-1 own a valid value. + float red = reduce_(); + + // Make sure we can write to shared memory. + __syncthreads(); + + // Finalize the reduction. + shuffle(dst, red); + } + } + + // Scale all the elements. + inline __device__ void scale(float const (&sum)[MMAS_M]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. + float inv_sum[MMAS_M]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + } + +// Update the values. 
+#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * 8; ++ni) { + elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // The elements. + float tmp_00 = this->elt_[mi][8 * ni + 0]; + float tmp_01 = this->elt_[mi][8 * ni + 1]; + float tmp_02 = this->elt_[mi][8 * ni + 2]; + float tmp_03 = this->elt_[mi][8 * ni + 3]; + float tmp_04 = this->elt_[mi][8 * ni + 4]; + float tmp_05 = this->elt_[mi][8 * ni + 5]; + float tmp_06 = this->elt_[mi][8 * ni + 6]; + float tmp_07 = this->elt_[mi][8 * ni + 7]; + + // Transform to accumulators. + acc[mi][ni].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + acc[mi][ni].reg(1) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + acc[mi][ni].reg(2) = fmha::float2_to_16bit_2(tmp_04, tmp_05); + acc[mi][ni].reg(3) = fmha::float2_to_16bit_2(tmp_06, tmp_07); + } + } + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Convert from FP16 fragments to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Normalize the values, and clamp to finite half. + uint32_t acc_0 = satfinite_h2(hmul2(acc[mi][ni].reg(0), params_scale_bmm1_)); + uint32_t acc_1 = satfinite_h2(hmul2(acc[mi][ni].reg(1), params_scale_bmm1_)); + uint32_t acc_2 = satfinite_h2(hmul2(acc[mi][ni].reg(2), params_scale_bmm1_)); + uint32_t acc_3 = satfinite_h2(hmul2(acc[mi][ni].reg(3), params_scale_bmm1_)); + + // Extract the values as floats. + half2_to_float2(this->elt_[mi][8 * ni + 0], this->elt_[mi][8 * ni + 1], acc_0); + half2_to_float2(this->elt_[mi][8 * ni + 2], this->elt_[mi][8 * ni + 3], acc_1); + half2_to_float2(this->elt_[mi][8 * ni + 4], this->elt_[mi][8 * ni + 5], acc_2); + half2_to_float2(this->elt_[mi][8 * ni + 6], this->elt_[mi][8 * ni + 7], acc_3); + + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { +#pragma unroll + for (int i = 0; i < 8; i++) { + // 1.0f / softcapping_scale has been fused to scale_bmm1. + this->elt_[mi][8 * ni + i] = + params_softcapping_scale_bmm1_ * __tanhf(this->elt_[mi][8 * ni + i]); + } + } + } + } + } + + // The scaling factor. + uint32_t const params_scale_bmm1_; + float const params_softcapping_scale_bmm1_; + // Shared memory for the CTA-wide reduction. + float *smem_, *smem_write_, *smem_read_; + // The current thread index. + int tidx_; + // The elements. + float elt_[MMAS_M][MMAS_N * 8]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_hmma { + // The traits. + using Traits = fmha::Turing_hmma_fp16_traits; + // The base class. + using Base = Softmax_hmma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Pack the data to a fragment for the next GEMM. 
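Before the pack() that follows, a note on the softcapping branch in unpack() above: since 1/softcapping_scale is already fused into scale_bmm1, the kernel only has to apply scale * tanh(x), which bounds every logit to (-softcapping_scale, +softcapping_scale). A host-side sketch of the full transform; the function and constants are illustrative assumptions, not the kernel's API.

#include <cmath>
#include <cstdio>

float softcap(float raw_logit, float softcapping_scale) {
  float x = raw_logit / softcapping_scale;   // this division is fused into scale_bmm1
  return softcapping_scale * tanhf(x);       // the part applied in unpack()
}

int main() {
  printf("%f %f\n",
         softcap(1.0f, 30.f),     // ~1.0: small logits pass through almost unchanged
         softcap(500.f, 30.f));   // ~30.0: large logits saturate at the cap
  return 0;
}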
+ template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(Base::Mma_tile::MMAS_M == M && Base::Mma_tile::MMAS_N * 4 == K * 2, ""); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 2 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][2 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][2 * ki + 1]; + + // 2nd row - 2 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][2 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][2 * ki + 1]; + + // Pack to 2 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_imma { + // The traits. + using Traits = fmha::Volta_imma_int8_int32_traits; + // The base class. + using Base = Softmax_imma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Repack. We could use store/load to match the Smem_tile API. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) { + static_assert(Base::Mma_tile::MMAS_M == M && Base::Mma_tile::MMAS_N == K, ""); + float const scale = reinterpret_cast(this->params_scale_softmax_); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0] * scale; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1] * scale; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2] * scale; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3] * scale; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0] * scale; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1] * scale; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2] * scale; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3] * scale; + + // Pack to 2 registers. + dst[ki][mi].reg(0) = float4_to_char4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = float4_to_char4(tmp_10, tmp_11, tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_imma { + // The traits. + using Traits = fmha::Turing_imma_int8_int32_traits; + // The base class. + using Base = Softmax_imma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Repack. We could use store/load to match the Smem_tile API. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) { + static_assert(Base::Mma_tile::MMAS_M == M && Base::Mma_tile::MMAS_N == K, ""); + float const scale = reinterpret_cast(this->params_scale_softmax_); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0] * scale; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1] * scale; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2] * scale; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3] * scale; + + // 2nd row - 4 elements per row. 
+ float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0] * scale; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1] * scale; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2] * scale; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3] * scale; + + // Pack to 2 registers. + dst[ki][mi].reg(0) = float4_to_char4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = float4_to_char4(tmp_10, tmp_11, tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_hmma { + // The traits. + using Traits = fmha::Ampere_hmma_fp16_traits; + // The base class. + using Base = Softmax_hmma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_fp32 : public Softmax_hmma { + // The base class. + using Base = Softmax_hmma; + // The fragment. + using Fragment_a = fmha::Fragment_a; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // Output accumulators (after conversion). + using Accumulator_out = fmha::Fragment_accumulator; + + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // DEBUG. + static_assert(Accumulator_out::NUM_REGS == 4, ""); + // END OF DEBUG. + + // DEBUG. + static_assert(std::is_same::value, ""); + + // END OF DEBUG. + + enum { WARPS_M = Cta_tile::WARPS_M }; + + enum { WARPS_N = Cta_tile::WARPS_N }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + using Smem_tile_red = Smem_tile_reduce; + static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); + + // Ctor. + template + inline __device__ Softmax_fp32(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + smem_sum_(static_cast(smem), tidx), + smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {} + + // Store the tile after softmax. 
+ template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator_out acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // The elements. + float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3]; + float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3]; + + // Transform to accumulators. + acc[mi][ni].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + acc[mi][ni].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + acc[mi][ni].reg(2) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + acc[mi][ni].reg(3) = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(Fragment_a::NUM_REGS == 4, ""); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_16bit_2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + } + + // Pack the data to a uint4 for the next operation. + template + inline __device__ void pack(uint4 (&dst)[M][N]) const { +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ni = 0; ni < N; ++ni) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3]; + + // Pack to 4 registers. + dst[mi][ni].x = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[mi][ni].y = fmha::float2_to_16bit_2(tmp_02, tmp_03); + dst[mi][ni].z = fmha::float2_to_16bit_2(tmp_10, tmp_11); + dst[mi][ni].w = fmha::float2_to_16bit_2(tmp_12, tmp_13); + } + } + } + + // Scale FP32 fragments + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scalef = reinterpret_cast(this->params_scale_bmm1_); + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // 1st row - 4 elements per row. 
+ this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; + // 2nd row - 4 elements per row. + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; + + // Attention logit softcapping scale. + // 1.0f / softcapping_scale has been fused to scale_bmm1. + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + this->elt_[2 * mi + 0][4 * ni + 0] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 0]); + this->elt_[2 * mi + 0][4 * ni + 1] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 1]); + this->elt_[2 * mi + 1][4 * ni + 0] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 0]); + this->elt_[2 * mi + 1][4 * ni + 1] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 1]); + this->elt_[2 * mi + 0][4 * ni + 2] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 2]); + this->elt_[2 * mi + 0][4 * ni + 3] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 0][4 * ni + 3]); + this->elt_[2 * mi + 1][4 * ni + 2] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 2]); + this->elt_[2 * mi + 1][4 * ni + 3] = + this->params_softcapping_scale_bmm1_ * __tanhf(this->elt_[2 * mi + 1][4 * ni + 3]); + } + } + } + } + + // Scale FP32 fragments + inline __device__ void unpack_noscale(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // 1st row - 4 elements per row. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); + // 2nd row - 4 elements per row. 
+ this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); + } + } + } + + template + __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator& op, Smem_tile_red& smem_red) { +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + frag[mi] = this->elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < 4 * MMAS_N; ni++) { + frag[mi] = op(frag[mi], this->elt_[mi][ni]); + } + } + quad_reduce(frag, frag, op); + + if (WARPS_N > 1) { + smem_red.store(frag); + __syncthreads(); + typename Smem_tile_red::read_t tmp[2 * MMAS_M]; + smem_red.load(tmp); + + quad_allreduce(frag, tmp, op); + } + } + + __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) { + MaxOp max; + reduce_(frag, max, smem_max_); + } + + __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) { + SumOp sum; + reduce_(frag, sum, smem_sum_); + } + + __device__ inline float correct(float warp_sum, float warp_max, float max) { + return warp_sum * __expf(warp_max - max); + } + + __device__ inline float2 correct(float2 warp_sum, float2 warp_max, float max) { + return {correct(warp_sum.x, warp_max.x, max), correct(warp_sum.y, warp_max.y, max)}; + } + + __device__ inline void online_softmax() { + MaxOp maxOp; + SumOp sumOp; + float max[2 * MMAS_M]; +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + max[mi] = this->elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < 4 * MMAS_N; ni++) { + max[mi] = maxOp(max[mi], this->elt_[mi][ni]); + } + } + quad_allreduce(max, max, maxOp); + smem_max_.store(max); + float sum[2 * MMAS_M]; +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + sum[mi] = 0.f; +#pragma unroll + for (int ni = 0; ni < 4 * MMAS_N; ni++) { + float x = this->elt_[mi][ni]; + this->elt_[mi][ni] = __expf(x - max[mi]); + sum[mi] += this->elt_[mi][ni]; + } + } + quad_allreduce(sum, sum, sumOp); + smem_sum_.store(sum); + + __syncthreads(); + + typename Smem_tile_red::read_t tmp_max[2 * MMAS_M]; + typename Smem_tile_red::read_t tmp_sum[2 * MMAS_M]; + smem_max_.load(tmp_max); + smem_sum_.load(tmp_sum); + float full_max[2 * MMAS_M]; + quad_allreduce(full_max, tmp_max, maxOp); +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + tmp_sum[mi] = correct(tmp_sum[mi], tmp_max[mi], full_max[mi]); + } + quad_allreduce(sum, tmp_sum, sumOp); +#pragma unroll + for (int mi = 0; mi < 2 * MMAS_M; mi++) { + float correction = __expf(max[mi] - full_max[mi]) / sum[mi]; +#pragma unroll + for (int ni = 0; ni < 4 * MMAS_N; ni++) { + this->elt_[mi][ni] *= correction; + } + } + } + + Smem_tile_red smem_max_; + Smem_tile_red smem_sum_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_fp32_traits; + // The base class. + using Base = Softmax_fp32; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_fp32 { + // The traits. + using Traits = fmha::Turing_hmma_fp32_traits; + // The base class. + using Base = Softmax_fp32; + // The fragment. 
+ using Fragment_a = fmha::Fragment_a; + // Softmax dst data_type (BMM2 input) + using Dst_type = typename Traits::A_type; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(Fragment_a::NUM_REGS == 2, ""); + static_assert(Base::Mma_tile::MMAS_M == M && Base::Mma_tile::MMAS_N * 4 == K * 2, ""); +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 2 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][2 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][2 * ki + 1]; + + // 2nd row - 2 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][2 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][2 * ki + 1]; + + // Pack to 2 registers. + dst[ki][mi].reg(0) = fmha::float2_to_16bit_2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_16bit_2(tmp_10, tmp_11); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_fp32 { + // The traits. + using Traits = fmha::Ampere_hmma_bf16_traits; + // The base class. + using Base = Softmax_fp32; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_imma { + // The traits. + using Traits = fmha::Ampere_imma_int8_int32_traits; + // The base class. + using Base = Softmax_imma; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_qmma { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Softmax_qmma; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_qmma { + // The traits. + using Traits = fmha::Ada_qmma_e4m3_fp16_traits; + // The base class. + using Base = Softmax_qmma; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax + : public Softmax_imma { + // The Traits + using Traits = fmha::Ada_qmma_e4m3_fp32_traits; + // The base class. + using Base = Softmax_imma; + + // The MMAs. + enum { MMAS_M = Base::MMAS_M }; + + enum { MMAS_N = Base::MMAS_N }; + + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), + params_scale_bmm1_(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1), + params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. 
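+  // store() rescales the probabilities by scale_softmax and hands FP32 accumulators to
+  // the gmem tile; see the TODO below about where the FP32 -> FP8 conversion belongs.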
+ template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // scale + acc[mi][ni].ele(0) = this->elt_[2 * mi + 0][4 * ni + 0] * scale; + acc[mi][ni].ele(1) = this->elt_[2 * mi + 0][4 * ni + 1] * scale; + acc[mi][ni].ele(4) = this->elt_[2 * mi + 0][4 * ni + 2] * scale; + acc[mi][ni].ele(5) = this->elt_[2 * mi + 0][4 * ni + 3] * scale; + acc[mi][ni].ele(2) = this->elt_[2 * mi + 1][4 * ni + 0] * scale; + acc[mi][ni].ele(3) = this->elt_[2 * mi + 1][4 * ni + 1] * scale; + acc[mi][ni].ele(6) = this->elt_[2 * mi + 1][4 * ni + 2] * scale; + acc[mi][ni].ele(7) = this->elt_[2 * mi + 1][4 * ni + 3] * scale; + } + } + + // Delegate to the gmem tile to store. + // TODO: need fp32 to fp8 conversion (move this to gmem_tile) + gmem_tile.store(acc); + } + + // Convert from accumulators to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scale = params_scale_q_ * params_scale_k_; +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { + // Convert to FP32 and scale. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scale; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scale; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scale; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scale; + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scale; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scale; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scale; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scale; + } + } + } + + template + inline __device__ void apply_exp_with_mask(float const (&max)[MMAS_M * 2]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * 2; ++mi) { + float max_val = APPLY_MASK && max[mi] == -FLT_MAX + ? 0.f + : (max[mi] - logf(Traits::SOFTMAX_FP_QUANT_SCALE)); +#pragma unroll + for (int ni = 0; ni < MMAS_N * 4; ++ni) { + this->elt_[mi][ni] = expf(this->elt_[mi][ni] - max_val); + } + } + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + float const scale = reinterpret_cast(this->params_scale_softmax_); + +// The canonical layout in K should be R0: [0,1,2,3] R2: [16,17,18,19] +// Note below that this is not possible with the register layout of the accumulator. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 8 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][8 * ki + 0] * scale; // + 0 + float tmp_01 = this->elt_[2 * mi + 0][8 * ki + 1] * scale; // + 1 + float tmp_02 = this->elt_[2 * mi + 0][8 * ki + 2] * scale; // + 8 + float tmp_03 = this->elt_[2 * mi + 0][8 * ki + 3] * scale; // + 9 + float tmp_04 = this->elt_[2 * mi + 0][8 * ki + 4] * scale; // +16 + float tmp_05 = this->elt_[2 * mi + 0][8 * ki + 5] * scale; // +17 + float tmp_06 = this->elt_[2 * mi + 0][8 * ki + 6] * scale; // +24 + float tmp_07 = this->elt_[2 * mi + 0][8 * ki + 7] * scale; // +25 + + // 2nd row - 4 elements per row. 
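+        // (also 8 elements here, at the same column offsets as the first row)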
+ float tmp_10 = this->elt_[2 * mi + 1][8 * ki + 0] * scale; // + 0 + float tmp_11 = this->elt_[2 * mi + 1][8 * ki + 1] * scale; // + 1 + float tmp_12 = this->elt_[2 * mi + 1][8 * ki + 2] * scale; // + 8 + float tmp_13 = this->elt_[2 * mi + 1][8 * ki + 3] * scale; // + 9 + float tmp_14 = this->elt_[2 * mi + 1][8 * ki + 4] * scale; // +16 + float tmp_15 = this->elt_[2 * mi + 1][8 * ki + 5] * scale; // +17 + float tmp_16 = this->elt_[2 * mi + 1][8 * ki + 6] * scale; // +24 + float tmp_17 = this->elt_[2 * mi + 1][8 * ki + 7] * scale; // +25 + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float4_to_fp8x4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = fmha::float4_to_fp8x4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[ki][mi].reg(2) = fmha::float4_to_fp8x4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[ki][mi].reg(3) = fmha::float4_to_fp8x4(tmp_14, tmp_15, tmp_16, tmp_17); + } + } + } + + template + inline __device__ void move_to_first_block(Params const& params, int bidb, int bidh, int q_loop) { + int scale_q_iter = + bidb * params.h * params.sage.q.max_nblock + bidh * params.sage.q.max_nblock + q_loop; + params_scale_q_ = __ldg(params.sage.q.scales + scale_q_iter); + params_scale_q_ *= reinterpret_cast(params_scale_bmm1_); + + int scale_k_iter = bidb * params.h * params.sage.k.max_nblock + bidh * params.sage.k.max_nblock; + params_scale_k_iter = reinterpret_cast(params.sage.k.scales + scale_k_iter); + params_scale_k_ = __ldg(params_scale_k_iter); + } + + inline __device__ void move_to_next_block() { + params_scale_k_iter += 1; + params_scale_k_ = __ldg(params_scale_k_iter); + } + + // The scaling factors. + uint32_t const params_scale_bmm1_, params_scale_softmax_; + float params_scale_q_, params_scale_k_; + float const* params_scale_k_iter; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// HOPPER SOFTMAX + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_gmma_base {}; + +template +struct Softmax_gmma_base { + // The instruction traits. + using Traits = Traits_; + // The Cta_tile. + using Cta_tile = Cta_tile_; + // The Kernel traits. + using Kernel_traits = Kernel_traits_; + // The accumulators. + using Accumulator = fmha::Fragment_accumulator; + // The Mma tile. + using Mma_tile = typename Traits::template Mma_tile; + + static_assert(Cta_tile::WARPS_M == 4); + static_assert(Mma_tile::M_PER_MMA_PER_CTA == 64); + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // Elements per thread per core matrix. + enum { ELTS_PER_THREAD = 2 }; + + // Core matrix is always 8x4. + enum { THREADS_PER_ROW = 4 }; + + enum { SMEM_BYTES = 0 }; + + // The number of rows accessed by each thread. + enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + static_assert(ROWS_PER_THREAD == Mma_tile::ROWS_PER_THREAD); + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // The number of total elements per thread. 
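+  // (equals GMMA_N / THREADS_PER_ROW, i.e. one thread's share of a row)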
+ enum { TOTAL_ELTS_PER_THREAD = ELTS_PER_THREAD * COLS_PER_THREAD }; + + template + inline __device__ Softmax_gmma_base(Params const& params, void*, int const, int const) + : params_scale_bmm1_(params.scale_bmm1), + params_softcapping_scale_bmm1_(params.softcapping_scale_bmm1) {} + + // Apply mask before softmax. Use 1 byte per MMA distributed as 2x4. + template + inline __device__ void apply_mask(Mask const& mask) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < TOTAL_ELTS_PER_THREAD; ++jj) { + if (!mask.is_valid(mi, ni, ii, jj)) { + this->elt_[ROWS_PER_THREAD * mi + ii][TOTAL_ELTS_PER_THREAD * ni + jj] = -FLT_MAX; + } + } // jj + } // ni + } // ii + } // mi + } + + template + inline __device__ void apply_mask_alibi(Mask const& mask, int head_id, + AlibiParams const& alibi_params) { + // 'if constexpr' because ALiBi is only defined for causal masks + if constexpr (Kernel_traits::CAUSAL_MASK) { + float m = get_alibi_head_scaling_factor(head_id, alibi_params); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int jj = 0; jj < TOTAL_ELTS_PER_THREAD; ++jj) { + int row, col; + mask.get_row_col(row, col, mi, ni, ii, jj); + if (mask.is_valid(row, col)) { + // Since softmax is shift invariant, + // it is sufficient just to use the column as the multiplier + elt_[ROWS_PER_THREAD * mi + ii][TOTAL_ELTS_PER_THREAD * ni + jj] = + elt_[ROWS_PER_THREAD * mi + ii][TOTAL_ELTS_PER_THREAD * ni + jj] * + alibi_params.scale_after_alibi + + m * (col + alibi_params.sequence_pos_offset); + } else { + elt_[ROWS_PER_THREAD * mi + ii][TOTAL_ELTS_PER_THREAD * ni + jj] = -FLT_MAX; + } + } + } + } + } + } else { + __builtin_unreachable(); + } + } + + // Do a CTA-wide reduction. + template + inline __device__ void reduce_4x1(float (&dst)[MMAS_M * ROWS_PER_THREAD]) { +#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD == MMAS_N * Mma_tile::CORES_N * 2); + if (Functor::IS_SUM) { +// Apply the summation inside the thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { + dst[mi] = (this->elt_[mi][0] + this->elt_[mi][1]); +#pragma unroll + for (int ni = 1; ni < MMAS_N * Mma_tile::CORES_N; ni++) { + dst[mi] += (this->elt_[mi][ni * 2 + 0] + this->elt_[mi][ni * 2 + 1]); + } + } + } else +#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE) + { +// find the max/sum for each row. +// For hopper, each row is held entirely within 4 threads. +// Apply the functor for each row inside a thread. +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { + dst[mi] = this->elt_[mi][0]; +#pragma unroll + for (int ni = 1; ni < MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD; ++ni) { + dst[mi] = Functor::apply(dst[mi], this->elt_[mi][ni]); + } + } + } +// Apply the functor for each row inside each group of 4 threads. +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1)); + __syncwarp(); + dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2)); + __syncwarp(); + } + } + + // Do a CTA-wide reduction. 
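+  // When a row is not split across warps in the N dimension, the quad shuffles in
+  // reduce_4x1 already see the whole row, so reduce() simply forwards to it. The
+  // WARPS_N == 2 variant below additionally combines the two halves through shared memory.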
+ template + inline __device__ void reduce(float (&dst)[MMAS_M * ROWS_PER_THREAD]) { + reduce_4x1(dst); + } + + // Apply the exp to all the elements. + inline __device__ void apply_exp(float const (&max)[MMAS_M * ROWS_PER_THREAD]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD; ++ni) { + this->elt_[mi][ni] = apply_exp_(this->elt_[mi][ni], max[mi]); + } + } + } + + // Scale all the elements. + inline __device__ void scale(float const (&sum)[MMAS_M * ROWS_PER_THREAD]) { + // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. + float inv_sum[MMAS_M * ROWS_PER_THREAD]; +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { + inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; + } + +// Update the values. +#pragma unroll + for (int mi = 0; mi < MMAS_M * ROWS_PER_THREAD; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD; ++ni) { + this->elt_[mi][ni] *= inv_sum[mi]; + } + } + } + + // The scalig factor. Depens on acc type, e.g. float for 32-bit and fp16x2/bf16x2 for 16-bit. + uint32_t const params_scale_bmm1_; + float const params_softcapping_scale_bmm1_; + // The elements. + float elt_[MMAS_M * ROWS_PER_THREAD][MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD]; +}; + +template +struct Softmax_gmma_base + : public Softmax_gmma_base { + using Base = Softmax_gmma_base; + + using Mma_tile = typename Base::Mma_tile; + + enum { BYTES_PER_SMEM = Mma_tile::M_PER_MMA_PER_CTA * Cta_tile::WARPS_N * sizeof(float) }; + + enum { ELTS_PER_ROW = 2 }; + + static_assert(Cta_tile::WARPS_N == 2); + static_assert(Cta_tile::WARPS_M == 4); + static_assert(Mma_tile::M_PER_MMA_PER_CTA == 64); + + template + inline __device__ Softmax_gmma_base(Params const& params, void* smem, int const bidb, + int const tidx) + : Base(params, smem, bidb, tidx) { + int const warp = tidx / Cta_tile::THREADS_PER_WARP; + int const warp_n = warp / 4; + int const warp_m = warp % 4; + int const lane = tidx % Cta_tile::THREADS_PER_WARP; + int const quad = lane / 4; + is_writer_ = lane % 4 == 0; + + int const col = warp_n; + int const row = warp_m * 16 + quad; + + smem_write_ = static_cast(smem) + row * 2 + col; + smem_read_ = static_cast(smem) + row; + } + + // Do a CTA-wide reduction. + template + inline __device__ void reduce(float (&dst)[2]) { + Base::template reduce_4x1(dst); + if (is_writer_) { + smem_write_[0 * ELTS_PER_ROW] = dst[0]; + smem_write_[8 * ELTS_PER_ROW] = dst[1]; + } + __syncthreads(); + float2 tmp0 = smem_read_[0]; + float2 tmp1 = smem_read_[8]; + dst[0] = Functor::apply(tmp0.x, tmp0.y); + dst[1] = Functor::apply(tmp1.x, tmp1.y); + } + + float* smem_write_; + float2* smem_read_; + bool is_writer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile_, Kernel_traits_> + : public Softmax_gmma_base< + fmha::Hopper_hgmma_fp16_traits, Cta_tile_, + Kernel_traits_, Cta_tile_::WARPS_N> { + // The traits. + using Traits = fmha::Hopper_hgmma_fp16_traits; + // Cta_tile. + using Cta_tile = Cta_tile_; + // Kernel_traits. + using Kernel_traits = Kernel_traits_; + // The Base class. + using Base = Softmax_gmma_base; + // The accumulators. + using Accumulator = typename Base::Accumulator; + // The Mma tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in M/N dimensions. 
+ enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // for HGMMA_FP16, there are 2 elements per RF for ACC. + enum { ELTS_PER_THREAD = 2 }; + + // for Hopper HGMMA, each row is held within 4 threads. + enum { THREADS_PER_ROW = 4 }; + + // The number of rows accessed by each thread. + enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Convert from FP16 fragments to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + + // Normalize the values. + uint32_t acc_0 = fmha::hmul2(acc[mi][ni].reg(col_idx * ROWS_PER_THREAD + row_idx), + this->params_scale_bmm1_); + // Element index. + int elt_row_idx = ROWS_PER_THREAD * mi + row_idx; + int elt_col_idx = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + // Extract the values as floats. + half2_to_float2(this->elt_[elt_row_idx][elt_col_idx + 0], + this->elt_[elt_row_idx][elt_col_idx + 1], acc_0); + // Attention logit softcapping scale. + // 1.0f / softcapping_scale has been fused to scale_bmm1. + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + this->elt_[elt_row_idx][elt_col_idx + 0] = + this->params_softcapping_scale_bmm1_ * + __tanhf(this->elt_[elt_row_idx][elt_col_idx + 0]); + this->elt_[elt_row_idx][elt_col_idx + 1] = + this->params_softcapping_scale_bmm1_ * + __tanhf(this->elt_[elt_row_idx][elt_col_idx + 1]); + } + } // row_idx + } // col_idx + } // ni + } // mi + } + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + float tmp_00 = + this->elt_[ROWS_PER_THREAD * mi + row_idx] + [COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD + 0]; + float tmp_01 = + this->elt_[ROWS_PER_THREAD * mi + row_idx] + [COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD + 1]; + acc[mi][ni].reg(col_idx * ROWS_PER_THREAD + row_idx) = + fmha::float2_to_half2(tmp_00, tmp_01); + } // row_idx + } // col_idx + } // ni + } // m + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +// we know the instruction shape is 64xNx16 +// Thus for input A matrix, it is of size 64x16 per warpgroup. 
+// Thus, each threads access 2 rows and 4 columns. contiguous 2 columns are held by 1 RF. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile_, Kernel_traits_> + : public Softmax_gmma_base< + fmha::Hopper_hgmma_fp32_traits, Cta_tile_, + Kernel_traits_, Cta_tile_::WARPS_N> { + // The traits. + using Traits = fmha::Hopper_hgmma_fp32_traits; + // Cta_tile. + using Cta_tile = Cta_tile_; + // Kernel_traits. + using Kernel_traits = Kernel_traits_; + // The Base class. + using Base = Softmax_gmma_base; + // The accumulators. + using Accumulator = typename Base::Accumulator; + // The Mma tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // for HGMMA_FP16, there are 2 elements per RF for ACC. + enum { ELTS_PER_THREAD = 2 }; + + // for Hopper HGMMA, each row is held within 4 threads. + enum { THREADS_PER_ROW = 4 }; + + // The number of rows accessed by each thread. + enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Convert from FP16 fragments to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const& scale_f = reinterpret_cast(this->params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + int elt_row = ROWS_PER_THREAD * mi + row_idx; + int elt_col = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + + float elt0 = acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 0) * scale_f; + float elt1 = acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 1) * scale_f; + + // 1.0f / softcapping_scale has been fused to scale_bmm1. 
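+            // i.e. the capped logit is softcapping_scale * tanh(logit / softcapping_scale),
+            // with the 1 / softcapping_scale factor already applied via scale_bmm1 above.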
+ if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + elt0 = this->params_softcapping_scale_bmm1_ * __tanhf(elt0); + elt1 = this->params_softcapping_scale_bmm1_ * __tanhf(elt1); + } + + this->elt_[elt_row][elt_col + 0] = elt0; + this->elt_[elt_row][elt_col + 1] = elt1; + + } // row_idx + } // col_idx + } // ni + } // mi + } + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally + int elt_row = ROWS_PER_THREAD * mi + row_idx; + int elt_col = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + float elt0 = this->elt_[elt_row][elt_col + 0]; + float elt1 = this->elt_[elt_row][elt_col + 1]; + + acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 0) = elt0; + acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 1) = elt1; + } // row_idx + } // col_idx + } // ni + } // m + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +// we know the instruction shape is 64xNx16 +// Thus for input A matrix, it is of size 64x16 per warpgroup. +// Thus, each threads access 2 rows and 4 columns. contiguous 2 columns are held by 1 RF. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile_, Kernel_traits_> + : public Softmax_gmma_base< + fmha::Hopper_hgmma_bf16_traits, Cta_tile_, + Kernel_traits_, Cta_tile_::WARPS_N> { + // The traits. + using Traits = fmha::Hopper_hgmma_bf16_traits; + // Cta_tile. + using Cta_tile = Cta_tile_; + // Kernel_traits. + using Kernel_traits = Kernel_traits_; + // The Base class. + using Base = Softmax_gmma_base; + // The accumulators. + using Accumulator = typename Base::Accumulator; + // The Mma tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // for HGMMA_FP16, there are 2 elements per RF for ACC. + enum { ELTS_PER_THREAD = 2 }; + + // for Hopper HGMMA, each row is held within 4 threads. + enum { THREADS_PER_ROW = 4 }; + + // The number of rows accessed by each thread. 
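+  // For the 64xNx16 HGMMA shape used here: 64 / (32 / 4) / 4 = 2 rows per thread.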
+ enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + // The number of columns access by each thread. + // Note there are 2 elements per reg. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // Use BMM1 softcapping scale or not. + enum { ENABLE_BMM1_SOFTCAPPING_SCALE = Kernel_traits::ENABLE_BMM1_SOFTCAPPING_SCALE }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx) {} + + // Convert from FP16 fragments to floats. + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const& scale_f = reinterpret_cast(this->params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + int elt_row = ROWS_PER_THREAD * mi + row_idx; + int elt_col = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + + float elt0 = acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 0) * scale_f; + float elt1 = acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 1) * scale_f; + + if constexpr (ENABLE_BMM1_SOFTCAPPING_SCALE) { + elt0 = this->params_softcapping_scale_bmm1_ * __tanhf(elt0); + elt1 = this->params_softcapping_scale_bmm1_ * __tanhf(elt1); + } + + this->elt_[elt_row][elt_col + 0] = elt0; + this->elt_[elt_row][elt_col + 1] = elt1; + + } // row_idx + } // col_idx + } // ni + } // mi + } + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int col_idx = 0; col_idx < COLS_PER_THREAD; ++col_idx) { +#pragma unroll + for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { + // the order of the acc rf is we traverse vertically first + // then we traverse horizontally. + int elt_row = ROWS_PER_THREAD * mi + row_idx; + int elt_col = COLS_PER_THREAD * ELTS_PER_THREAD * ni + col_idx * ELTS_PER_THREAD; + float elt0 = this->elt_[elt_row][elt_col + 0]; + float elt1 = this->elt_[elt_row][elt_col + 1]; + + acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 0) = elt0; + acc[mi][ni].elt(col_idx * 2 * ROWS_PER_THREAD + 2 * row_idx + 1) = elt1; + } // row_idx + } // col_idx + } // ni + } // m + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { +// we know the instruction shape is 64xNx16 +// Thus for input A matrix, it is of size 64x16 per warpgroup. +// Thus, each threads access 2 rows and 4 columns. contiguous 2 columns are held by 1 RF. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 4 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. 
+ float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float2_to_bf16_x2(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_to_bf16_x2(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_to_bf16_x2(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_to_bf16_x2(tmp_12, tmp_13); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax_gmma_32bit_8bit_base + : public Softmax_gmma_base { + // The Base class. + using Base = Softmax_gmma_base; + // The accumulators. + using Accumulator = typename Base::Accumulator; + // The Mma tile. + using Mma_tile = typename Base::Mma_tile; + + // The number of MMAs in M/N dimensions. + enum { MMAS_M = Mma_tile::MMAS_M }; + + enum { MMAS_N = Mma_tile::MMAS_N }; + + // TODO these should be general. + // Two elts per thread per acc core matrix. + enum { ELTS_PER_THREAD = 2 }; + + // Number of threads per row of the acc core matrix. + enum { THREADS_PER_ROW = 4 }; + + // The number of rows accessed by each thread per GMMA. + enum { + ROWS_PER_THREAD = + Traits::GMMA_M / (Cta_tile::THREADS_PER_WARP / THREADS_PER_ROW) / Cta_tile::WARPS_M + }; + + // The number of columns access by each thread. + enum { COLS_PER_THREAD = Traits::GMMA_N / THREADS_PER_ROW / ELTS_PER_THREAD }; + + // Check the expected number of accumulator elements. + static_assert(Accumulator::NUM_ELTS == COLS_PER_THREAD * ROWS_PER_THREAD * ELTS_PER_THREAD); + + // Ctor. + template + inline __device__ Softmax_gmma_32bit_8bit_base(Params const& params, void* smem, int bidb, + int tidx) + : Base(params, smem, bidb, tidx) {} + + inline __device__ void unpack(Accumulator const (&acc)[MMAS_M][MMAS_N]) { + float const scalef = reinterpret_cast(this->params_scale_bmm1_); +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < COLS_PER_THREAD; ++ii) { + float tmp_00 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 0 * ELTS_PER_THREAD + 0) * + scalef; + float tmp_01 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 0 * ELTS_PER_THREAD + 1) * + scalef; + float tmp_10 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 1 * ELTS_PER_THREAD + 0) * + scalef; + float tmp_11 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 1 * ELTS_PER_THREAD + 1) * + scalef; + int n_offset = ni * COLS_PER_THREAD * ELTS_PER_THREAD + ii * ELTS_PER_THREAD; + this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 0] = tmp_00; + this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 1] = tmp_01; + this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 0] = tmp_10; + this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 1] = tmp_11; + } // ii + } // ni + } // mi + } + + inline __device__ void unpack_noscale(Accumulator const (&acc)[MMAS_M][MMAS_N]) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < COLS_PER_THREAD; ++ii) { + float tmp_00 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 0 * ELTS_PER_THREAD + 0); + float tmp_01 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 0 * ELTS_PER_THREAD + 1); + float tmp_10 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 1 * ELTS_PER_THREAD + 
0); + float tmp_11 = + acc[mi][ni].elt(ii * ROWS_PER_THREAD * ELTS_PER_THREAD + 1 * ELTS_PER_THREAD + 1); + int n_offset = ni * COLS_PER_THREAD * ELTS_PER_THREAD + ii * ELTS_PER_THREAD; + this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 0] = tmp_00; + this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 1] = tmp_01; + this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 0] = tmp_10; + this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 1] = tmp_11; + } // ii + } // ni + } // mi + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile, Kernel_traits> + : public Softmax_gmma_32bit_8bit_base< + fmha::Hopper_qgmma_e4m3_fp32_traits, + Cta_tile, Kernel_traits> { + // The traits. + using Traits = fmha::Hopper_qgmma_e4m3_fp32_traits; + // The Base class. + using Base = Softmax_gmma_32bit_8bit_base; + + using Accumulator = typename Base::Accumulator; + + enum { + MMAS_M = Base::MMAS_M, + MMAS_N = Base::MMAS_N, + ROWS_PER_THREAD = Base::ROWS_PER_THREAD, + COLS_PER_THREAD = Base::COLS_PER_THREAD, + ELTS_PER_THREAD = Base::ELTS_PER_THREAD, + }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(this->params_scale_softmax_); + + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < COLS_PER_THREAD; ++ii) { + int row = mi * ROWS_PER_THREAD; + int col = ni * COLS_PER_THREAD * ELTS_PER_THREAD + ii * ELTS_PER_THREAD; + float tmp_00 = this->elt_[row + 0][col + 0] * scale; + float tmp_01 = this->elt_[row + 0][col + 1] * scale; + float tmp_10 = this->elt_[row + 1][col + 0] * scale; + float tmp_11 = this->elt_[row + 1][col + 1] * scale; + + int elt_idx = ii * ROWS_PER_THREAD * ELTS_PER_THREAD; + acc[mi][ni].elt(elt_idx + 0 * ELTS_PER_THREAD + 0) = tmp_00; + acc[mi][ni].elt(elt_idx + 0 * ELTS_PER_THREAD + 1) = tmp_01; + acc[mi][ni].elt(elt_idx + 1 * ELTS_PER_THREAD + 0) = tmp_10; + acc[mi][ni].elt(elt_idx + 1 * ELTS_PER_THREAD + 1) = tmp_11; + } // ii + } // ni + } // mi + + // Delegate to the gmem tile to store. + gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(M == 1); + static_assert(Fragment_a::NUM_REGS == 4); + static_assert(Fragment_a::NUM_ELTS == 16); + // Acc per warp: 16 x 256 FP32 + // A is 8 times(in K) 16 x 32 FP8, i.e. 4 registers per thread. + + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD % 8 == 0); + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD == K * Fragment_a::NUM_ELTS / 2); + + float const scale = reinterpret_cast(this->params_scale_softmax_); + +// The canonical layout in K should be R0: [0,1,2,3] R2: [16,17,18,19] +// Note below that this is not possible with the register layout of the accumulator. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 8 elements per row. 
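+        // The 8 values per row span four accumulator core matrices (two adjacent columns
+        // from each), hence the column offsets +0,+1,+8,+9,+16,+17,+24,+25 annotated below;
+        // each row is then packed into two fp8x4 registers.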
+ float tmp_00 = this->elt_[2 * mi + 0][8 * ki + 0] * scale; // + 0 + float tmp_01 = this->elt_[2 * mi + 0][8 * ki + 1] * scale; // + 1 + float tmp_02 = this->elt_[2 * mi + 0][8 * ki + 2] * scale; // + 8 + float tmp_03 = this->elt_[2 * mi + 0][8 * ki + 3] * scale; // + 9 + float tmp_04 = this->elt_[2 * mi + 0][8 * ki + 4] * scale; // +16 + float tmp_05 = this->elt_[2 * mi + 0][8 * ki + 5] * scale; // +17 + float tmp_06 = this->elt_[2 * mi + 0][8 * ki + 6] * scale; // +24 + float tmp_07 = this->elt_[2 * mi + 0][8 * ki + 7] * scale; // +25 + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][8 * ki + 0] * scale; // + 0 + float tmp_11 = this->elt_[2 * mi + 1][8 * ki + 1] * scale; // + 1 + float tmp_12 = this->elt_[2 * mi + 1][8 * ki + 2] * scale; // + 8 + float tmp_13 = this->elt_[2 * mi + 1][8 * ki + 3] * scale; // + 9 + float tmp_14 = this->elt_[2 * mi + 1][8 * ki + 4] * scale; // +16 + float tmp_15 = this->elt_[2 * mi + 1][8 * ki + 5] * scale; // +17 + float tmp_16 = this->elt_[2 * mi + 1][8 * ki + 6] * scale; // +24 + float tmp_17 = this->elt_[2 * mi + 1][8 * ki + 7] * scale; // +25 + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float4_to_fp8x4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = fmha::float4_to_fp8x4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[ki][mi].reg(2) = fmha::float4_to_fp8x4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[ki][mi].reg(3) = fmha::float4_to_fp8x4(tmp_14, tmp_15, tmp_16, tmp_17); + } + } + } + + uint32_t const params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax, + Cta_tile, Kernel_traits> + : public Softmax_gmma_32bit_8bit_base< + fmha::Hopper_igmma_int8_int32_traits, + Cta_tile, Kernel_traits> { + // The traits. + using Traits = fmha::Hopper_igmma_int8_int32_traits; + + // The Base class. + using Base = Softmax_gmma_32bit_8bit_base; + + using Accumulator = typename Base::Accumulator; + + enum { + MMAS_M = Base::MMAS_M, + MMAS_N = Base::MMAS_N, + ROWS_PER_THREAD = Base::ROWS_PER_THREAD, + COLS_PER_THREAD = Base::COLS_PER_THREAD, + ELTS_PER_THREAD = Base::ELTS_PER_THREAD, + }; + + // Ctor. + template + inline __device__ Softmax(Params const& params, void* smem, int bidb, int tidx) + : Base(params, smem, bidb, tidx), params_scale_softmax_(params.scale_softmax) {} + + // Store the tile after softmax. + template + inline __device__ void store(Gmem_tile& gmem_tile) { + float const scale = reinterpret_cast(this->params_scale_softmax_); + Accumulator acc[MMAS_M][MMAS_N]; + +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ni = 0; ni < MMAS_N; ++ni) { +#pragma unroll + for (int ii = 0; ii < COLS_PER_THREAD; ++ii) { + int n_offset = ni * COLS_PER_THREAD * ELTS_PER_THREAD + ii * ELTS_PER_THREAD; + float tmp_00 = this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 0]; + float tmp_01 = this->elt_[mi * ROWS_PER_THREAD + 0][n_offset + 1]; + float tmp_10 = this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 0]; + float tmp_11 = this->elt_[mi * ROWS_PER_THREAD + 1][n_offset + 1]; + + int elt_offset = ii * ROWS_PER_THREAD * ELTS_PER_THREAD; + acc[mi][ni].elt(elt_offset + 0 * ELTS_PER_THREAD + 0) = tmp_00 * scale; + acc[mi][ni].elt(elt_offset + 0 * ELTS_PER_THREAD + 1) = tmp_01 * scale; + acc[mi][ni].elt(elt_offset + 1 * ELTS_PER_THREAD + 0) = tmp_10 * scale; + acc[mi][ni].elt(elt_offset + 1 * ELTS_PER_THREAD + 1) = tmp_11 * scale; + } // ii + } // ni + } // mi + + // Delegate to the gmem tile to store. 
+ gmem_tile.store(acc); + } + + // Pack the data to a fragment for the next GEMM. + template + inline __device__ void pack(Fragment_a (&dst)[K][M]) const { + static_assert(M == 1); + static_assert(Fragment_a::NUM_REGS == 4); + static_assert(Fragment_a::NUM_ELTS == 16); + // Acc per warp: 16 x 256 FP32 + // A is 8 times(in K) 16 x 32 FP8, i.e. 4 registers per thread. + + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD % 8 == 0); + static_assert(MMAS_N * COLS_PER_THREAD * ELTS_PER_THREAD == K * Fragment_a::NUM_ELTS / 2); + + float const scale = reinterpret_cast(this->params_scale_softmax_); +// The canonical layout in K should be R0: [0,1,2,3] R2: [16,17,18,19] +// Note below that this is not possible with the register layout of the accumulator. +#pragma unroll + for (int mi = 0; mi < M; ++mi) { +#pragma unroll + for (int ki = 0; ki < K; ++ki) { + // 1st row - 8 elements per row. + float tmp_00 = this->elt_[2 * mi + 0][8 * ki + 0] * scale; // + 0 + float tmp_01 = this->elt_[2 * mi + 0][8 * ki + 1] * scale; // + 1 + float tmp_02 = this->elt_[2 * mi + 0][8 * ki + 2] * scale; // + 8 + float tmp_03 = this->elt_[2 * mi + 0][8 * ki + 3] * scale; // + 9 + float tmp_04 = this->elt_[2 * mi + 0][8 * ki + 4] * scale; // +16 + float tmp_05 = this->elt_[2 * mi + 0][8 * ki + 5] * scale; // +17 + float tmp_06 = this->elt_[2 * mi + 0][8 * ki + 6] * scale; // +24 + float tmp_07 = this->elt_[2 * mi + 0][8 * ki + 7] * scale; // +25 + + // 2nd row - 4 elements per row. + float tmp_10 = this->elt_[2 * mi + 1][8 * ki + 0] * scale; // + 0 + float tmp_11 = this->elt_[2 * mi + 1][8 * ki + 1] * scale; // + 1 + float tmp_12 = this->elt_[2 * mi + 1][8 * ki + 2] * scale; // + 8 + float tmp_13 = this->elt_[2 * mi + 1][8 * ki + 3] * scale; // + 9 + float tmp_14 = this->elt_[2 * mi + 1][8 * ki + 4] * scale; // +16 + float tmp_15 = this->elt_[2 * mi + 1][8 * ki + 5] * scale; // +17 + float tmp_16 = this->elt_[2 * mi + 1][8 * ki + 6] * scale; // +24 + float tmp_17 = this->elt_[2 * mi + 1][8 * ki + 7] * scale; // +25 + + // Pack to 4 registers. + dst[ki][mi].reg(0) = fmha::float4_to_char4(tmp_00, tmp_01, tmp_02, tmp_03); + dst[ki][mi].reg(1) = fmha::float4_to_char4(tmp_10, tmp_11, tmp_12, tmp_13); + dst[ki][mi].reg(2) = fmha::float4_to_char4(tmp_04, tmp_05, tmp_06, tmp_07); + dst[ki][mi].reg(3) = fmha::float4_to_char4(tmp_14, tmp_15, tmp_16, tmp_17); + } + } + } + + uint32_t const params_scale_softmax_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The softmax normalization statistics used by flash attention (l, m) +template +struct Softmax_statistics { + // The shape of the MMA tile. + using Mma_tile = typename Traits::template Mma_tile; + + // The number of MMAs in the M dimension. + enum { MMAS_M = Mma_tile::MMAS_M }; + + // Ctor. + template + inline __device__ Softmax_statistics(Params const& params, void const* ptr, Binfo const& binfo, + int tidx) + : ptr_(reinterpret_cast(ptr)), seqlen_(binfo.actual_seqlen) { + // The decomposition of the thread index into warp/lane. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The position of the the warp in the CTA. + int warp_m = warp % Cta_tile::WARPS_M; + + // The position of the thread + token_ = warp_m * Mma_tile::M_PER_MMA + lane / 4; + + // Compute the offset to the first token of the sequence. + int64_t offset = binfo.bidb * params.h + binfo.bidh; + // Move the pointer to the correct position. 
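+    // offset selects the (batch, head) slice of the softmax statistics; consecutive
+    // slices are lse_stride_in_bytes apart.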
+ ptr_ += offset * params.lse_stride_in_bytes; + } + + // Load the bias into registers (and expand). + inline __device__ void load(int step) { +#pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + // The index of the token. + int token = token_; + // At each iteration we jump over STEPQ elements. + token += step * Cta_tile::M; + // The extra offset inside the CTA. + token += mi * Mma_tile::M_PER_MMA_PER_CTA + (ii & 0x1) * 8; + + // Fetch the value if the token is valid. + float val = 0.0f; + if (token < seqlen_) { + val = reinterpret_cast(ptr_)[token]; + } + lm_[2 * mi + ii] = val; + } + } + } + + // The pointer to the bias. + int8_t const* ptr_; + // The length of the sequence. + int const seqlen_; + // The token that this thread is loading. + int token_; + // The bias after expansion. + float lm_[MMAS_M * 2]; +}; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/traits.h b/csrc/fmha_v2/fmha/traits.h new file mode 100644 index 0000000000..bb6f4b700d --- /dev/null +++ b/csrc/fmha_v2/fmha/traits.h @@ -0,0 +1,942 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include + +#include "fmha/numeric_types.h" + +#define FMHA_DIV_UP(m, n) (((m) + (n) - 1) / (n)) + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Trait class for heuristically determining the tile sizes +template +struct Traits_tile_size; + +template +struct Traits_tile_size { + enum { + CTA_P_TILE_M = STEP, + CTA_P_TILE_N = S, + CTA_P_TILE_K = D, + CTA_O_TILE_M = CTA_P_TILE_M, + CTA_O_TILE_N = DV, + CTA_O_TILE_K = S + }; +}; + +template +struct Traits_tile_size { + enum { + CTA_P_TILE_M = STEP, + CTA_P_TILE_N = S, + // D =16: CTA_P_TILE_K=16 + // D =32: CTA_P_TILE_K=32 + // D>=64: CTA_P_TILE_K=64 + CTA_P_TILE_K = D < 32 ? 16 : (D < 64 ? 32 : 64), + CTA_O_TILE_M = CTA_P_TILE_M, + // D =512: CTA_TILE_N=256 + // D<=256: CTA_TILE_N=D + CTA_O_TILE_N = DV > 256 ? 256 : DV, + // D =512: CTA_O_TILE_K=16 + // D =256: CTA_O_TILE_K=32 + // D<=128: CTA_O_TILE_K=64 + CTA_O_TILE_K = std::max(K_PER_MMA, DV > 256 ? 16 : (DV > 128 ? 32 : 64)) + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The GPU architecture. + typename Gpu_arch, + // The number of rows in the CTA tile. + int M_, + // The number of cols in the CTA tile. + int N_, + // The number of elements in the the K dimension of the GEMM loop. + int K_, + // The number of valid cols in the CTA tile. + int VALID_N_, + // The number of valid elements in the the K dimension of the GEMM loop. + int VALID_K_, + // The number of rows of warps. + int WARPS_M_, + // The number of cols of warps. + int WARPS_N_, + // The number of warps in the K dimension of the GEMM loop. + int WARPS_K_> +struct Cta_tile_ { + enum { M = M_, N = N_, K = K_, VALID_N = VALID_N_, VALID_K = VALID_K_ }; + + // The number of warps. 
+ enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ }; + + // The number of warps per CTA. + enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K }; + + // The number of threads per warp. + enum { THREADS_PER_WARP = Gpu_arch::THREADS_PER_WARP }; + + // The number of threads per CTA. + enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The GPU architecture. + typename Gpu_arch_, + // The type of the elements of A. + typename A_type_, + // The type of the elements of B. + typename B_type_, + // The type of the elements of C. + typename C_type_, + // The type of the elements of the accumulators. + typename Accumulator_type_, + // The type of the elements of the epilogue. + typename Epilogue_type_> +struct Traits { + // The architecture. + using Gpu_arch = Gpu_arch_; + // The data type for A elements. + using A_type = A_type_; + // The data type for B elements. + using B_type = B_type_; + // The data type for C elements. + using C_type = C_type_; + // The data type for accumulators. + using Accumulator_type = Accumulator_type_; + // The data type of the math in the epilogue. + using Epilogue_type = Epilogue_type_; + + // Create the description of the CTA tile from a configuration. + template + using Cta_tile_extd = Cta_tile_; + + // The number of bits per element of A. + enum { BITS_PER_ELEMENT_A = sizeof(A_type) * 8 }; + + // An offset in bytes for A. + static inline __host__ __device__ int64_t offset_in_bytes_a(int64_t offset) { + return offset * static_cast(sizeof(A_type)); + } + + // The number of bits per element of B. + enum { BITS_PER_ELEMENT_B = sizeof(B_type) * 8 }; + + // An offset in bytes for B. + static inline __host__ __device__ int64_t offset_in_bytes_b(int64_t offset) { + return offset * static_cast(sizeof(B_type)); + } + + // The number of bits per element of C. + enum { BITS_PER_ELEMENT_C = sizeof(C_type) * 8 }; + + // An offset in bytes for C. + static inline __host__ __device__ int64_t offset_in_bytes_c(int64_t offset) { + return offset * static_cast(sizeof(C_type)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Gpu_arch_base { + // By default, architectures have 32 threads per warp. + enum { THREADS_PER_WARP = 32 }; + + // By default, architectures do not support LDGSTS. + enum { HAS_LDGSTS = 0 }; + + // By default, architecture do not support super HMMA + enum { HAS_SUPER_HMMA = 0 }; + + // By default, architecture do not support TMA + enum { HAS_TMA = 0 }; + + // By default, architecture do not support GMMA + enum { HAS_GMMA = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Cta_tile_with_k_with_padding = typename Traits_::template Cta_tile_extd< + Cta_tile_::M, Cta_tile_::N, Next_power_of_two::VALUE, Cta_tile_::N, + Next_power_of_two::VALUE, Cta_tile_::WARPS_M, Cta_tile_::WARPS_N, + Cta_tile_::WARPS_K>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Volta : public Gpu_arch_base {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Volta_mma_tile { + // The number of elements computed with a single warp-MMA. 
+ enum { M_PER_MMA = 16, N_PER_MMA = N_PER_MMA_, K_PER_MMA = K_PER_MMA_ }; + + // The number of elements computed with a single CTA-MMA. + enum { + M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K + }; + + // The number of MMAs needed to compute the GEMM. + enum { + MMAS_M = (Cta_tile::M + M_PER_MMA_PER_CTA - 1) / M_PER_MMA_PER_CTA, + MMAS_N = (Cta_tile::N + N_PER_MMA_PER_CTA - 1) / N_PER_MMA_PER_CTA, + MMAS_K = (Cta_tile::K + K_PER_MMA_PER_CTA - 1) / K_PER_MMA_PER_CTA + }; + + // The number of valid MMAs (for Head Size) + enum { + // tile o + VALID_MMAS_N = Div_up::VALUE, + // tile p + VALID_MMAS_K = Div_up::VALUE, + }; + + // The number of elements computed per warp. + enum { + M_PER_WARP = MMAS_M * M_PER_MMA, + N_PER_WARP = MMAS_N * N_PER_MMA, + K_PER_WARP = MMAS_K * K_PER_MMA, + }; + + // Do we enable the fast path for LDS. + enum { ENABLE_LDS_FAST_PATH = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Volta_hmma_fp16_traits + : public Traits { + // The K_PER_MMA for Volta_hmma_fp16_traits is 8. + enum { K_PER_MMA = 8 }; + + // The MMA tile. + template + using Mma_tile = Volta_mma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Volta_hmma_fp16_16x16x16_traits + : public Traits { + // The K_PER_MMA for Volta_hmma_fp16_16x16x16_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Volta_mma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Volta_imma_int8_int32_traits : public Traits { + // The K_PER_MMA for Volta_imma_int8_int32_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Volta_mma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Turing : public Gpu_arch_base {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Turing_mma_tile { + // The number of elements computed with a single warp-MMA. + enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = K_PER_MMA_ }; + + // The number of elements computed with a single CTA-MMA. + enum { + M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K + }; + + // The number of MMAs needed to compute the GEMM. + enum { + MMAS_M = Div_up::VALUE, + MMAS_N = Div_up::VALUE, + MMAS_K = Div_up::VALUE, + }; + + // The number of valid MMAs (for Head Size) + enum { + // tile o + VALID_MMAS_N = Div_up::VALUE, + // tile p + VALID_MMAS_K = Div_up::VALUE, + }; + + // The number of elements computed per warp. + enum { + M_PER_WARP = MMAS_M * M_PER_MMA, + N_PER_WARP = MMAS_N * N_PER_MMA, + K_PER_WARP = MMAS_K * K_PER_MMA, + }; + + // The distribution of threads in the output tile. 
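+  // The 32 threads of a warp are arranged as 8 rows x 4 columns over each MMA
+  // output fragment, which is what the two enums below encode.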
+ enum { + THREADS_PER_MMA_M = 8, + THREADS_PER_MMA_N = 4, + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Turing_hmma_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Turing_hmma_fp16_traits + : public Traits { + // The K_PER_MMA for Turing_hmma_fp16_traits is 8. + enum { K_PER_MMA = 8 }; + + // The MMA tile. + template + using Mma_tile = Turing_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Turing_hmma_fp32_traits : public Traits { + // The K_PER_MMA for Turing_hmma_fp32_traits is 8. + enum { K_PER_MMA = 8 }; + + // The MMA tile. + template + using Mma_tile = Turing_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Turing_imma_int8_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Turing_imma_int8_int32_traits + : public Traits { + // The K_PER_MMA for Turing_imma_int8_int32_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Turing_imma_int8_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere : public Gpu_arch_base { + // It has LDGSTS. + enum { HAS_LDGSTS = 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ampere_hmma_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere_hmma_fp16_traits + : public Traits { + // The K_PER_MMA for Ampere_hmma_fp16_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Ampere_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere_hmma_fp32_traits + : public Traits { + // The K_PER_MMA for Ampere_hmma_fp32_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Ampere_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// used for Epilogue_type = bf16_t (similar to Ampere_hmma_fp16_traits). +struct Ampere_hmma_bf16_bf16_traits + : public Traits { + // The K_PER_MMA for Ampere_hmma_bf16_bf16_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Ampere_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere_hmma_bf16_traits : public Traits { + // The K_PER_MMA for Ampere_hmma_bf16_traits is 16. + enum { K_PER_MMA = 16 }; + + // The MMA tile. + template + using Mma_tile = Ampere_hmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ampere_imma_int8_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ampere_imma_int8_int32_traits + : public Traits { + // The K_PER_MMA for Ampere_imma_int8_int32_traits is 32. + enum { K_PER_MMA = 32 }; + + // The MMA tile. 
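+  // Given a Cta_tile, this computes the per-CTA and per-warp MMA counts exactly
+  // like the HMMA tiles above, but with K_PER_MMA = 32 int8 elements per MMA.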
+ template + using Mma_tile = Ampere_imma_int8_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ada : public Gpu_arch_base { + // It has LDGSTS. + enum { HAS_LDGSTS = 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The following partial traits are mapped to Ampere_hmma_fp16_traits in fmha/kernel_traits.h. +// +// It is easier to implement setup.py this way. +struct Ada_hmma_fp16_traits {}; + +struct Ada_hmma_fp32_traits {}; + +struct Ada_imma_int8_int32_traits {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ada_qmma_fp8_tile : public Turing_mma_tile {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ada_qmma_e4m3_fp16_traits : public Traits { + // The K_PER_MMA for Ada_qmma_e4m3_fp16_traits is 32. + enum { K_PER_MMA = 32 }; + + // The MMA tile. + template + using Mma_tile = Ada_qmma_fp8_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Ada_qmma_e4m3_fp32_traits : public Traits { + // The K_PER_MMA for Ada_qmma_e4m3_fp32_traits is 32. + enum { K_PER_MMA = 32 }; + + // The MMA tile. + template + using Mma_tile = Ada_qmma_fp8_tile; + + static constexpr float SOFTMAX_FP_QUANT_SCALE = Softmax_fp_quant_scale(); + static constexpr float SOFTMAX_FP_DEQUANT_SCALE = 1.f / SOFTMAX_FP_QUANT_SCALE; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Hopper : public Gpu_arch_base { + // It has LDGSTS. + enum { HAS_LDGSTS = 1 }; + + // It has TMA. + enum { HAS_TMA = 1 }; + + // It has GMMA + enum { HAS_GMMA = 1 }; + + // for Hopper there are 4 warps per warpgroup. + enum { WARPS_PER_WARP_GROUP = 4 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper related code. +// SHOULD we move this to a different file?? +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct Hopper_cga_tile { + // The size of the CGA in terms of CTA + enum { CLUSTER_HEIGHT = HEIGHT_ }; + + enum { CLUSTER_WIDTH = WIDTH_ }; + + enum { CLUSTER_DEPTH = DEPTH_ }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template // Number of warp group along K dim +struct Hopper_cta_tile { + // GPU arch. + using Gpu_arch = Gpu_arch_; + + // The size of the CTA tile. + // TODO: support D (not power of 2) + enum { M = M_, N = N_, K = K_, VALID_N = VALID_N_, VALID_K = VALID_K_ }; + + // The number of warp groups. + enum { WARP_GROUP_M = WARP_GROUP_M_, WARP_GROUP_N = WARP_GROUP_N_, WARP_GROUP_K = WARP_GROUP_K_ }; + + // The number of warps in a warp group. + enum { + WARPS_M_PER_GROUP = 4, + WARPS_N_PER_GROUP = 1, + WARPS_K_PER_GROUP = 1, + }; + + // The number of warps in a cta. + enum { + WARPS_M = WARPS_M_PER_GROUP * WARP_GROUP_M_, + WARPS_N = WARPS_N_PER_GROUP * WARP_GROUP_N_, + WARPS_K = WARPS_K_PER_GROUP * WARP_GROUP_K_ + }; + + // The number of warps per CTA. + enum { + WARPS_PER_CTA = WARP_GROUP_M * WARP_GROUP_N * WARP_GROUP_K * Gpu_arch::WARPS_PER_WARP_GROUP + }; + + // The number of warps per warpgroup. + enum { WARPS_PER_WARP_GROUP = Gpu_arch::WARPS_PER_WARP_GROUP }; + + // The number of threads per warp. 
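+  // Illustrative example (hypothetical numbers): WARP_GROUP_M_ = 2, WARP_GROUP_N_ = 1,
+  // WARP_GROUP_K_ = 1 yields 2 warp groups, 8 warps and 256 threads per CTA through
+  // the derived enums below.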
+ enum { THREADS_PER_WARP = Gpu_arch::THREADS_PER_WARP }; + + // the number of threads per warpgroup. + enum { THREADS_PER_WARP_GROUP = THREADS_PER_WARP * WARPS_PER_WARP_GROUP }; + + // The number of threads per CTA. + enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP }; + + enum { GROUPS_M = 1 }; + + enum { GROUPS_N = 1 }; + + enum { GROUPS_K = 1 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hopper_gmma_tile { + // The number of elements computed with a single warp group mma. + enum { M_PER_MMA = GMMA_M, N_PER_MMA = GMMA_N, K_PER_MMA = GMMA_K }; + + // The number of warp groups. + enum { + NUM_WARP_GROUPS = Cta_tile::WARP_GROUP_M * Cta_tile::WARP_GROUP_N * Cta_tile::WARP_GROUP_K + }; + + // The number of elements computed with a single CTA-MMA. + enum { + M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARP_GROUP_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARP_GROUP_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARP_GROUP_K + }; + + // The number of MMAs needed to compute the GEMM. + enum { + MMAS_M = (Cta_tile::M + M_PER_MMA_PER_CTA - 1) / M_PER_MMA_PER_CTA, + MMAS_N = (Cta_tile::N + N_PER_MMA_PER_CTA - 1) / N_PER_MMA_PER_CTA, + MMAS_K = (Cta_tile::K + K_PER_MMA_PER_CTA - 1) / K_PER_MMA_PER_CTA, + }; + + // The number of valid MMAs (for Head Size) + enum { + // tile o + VALID_MMAS_N = Div_up::VALUE, + // tile p + VALID_MMAS_K = Div_up::VALUE, + }; + + // The number of elements computed per warp group. + enum { + M_PER_WARP_GROUP = MMAS_M * M_PER_MMA, + N_PER_WARP_GROUP = MMAS_N * N_PER_MMA, + K_PER_WARP_GROUP = MMAS_K * K_PER_MMA, + }; + + // the size of GMMA group, which is GMMA_M x GMMA_N x Kblock. + enum { + M_PER_GMMA_GROUP = GMMA_M, + N_PER_GMMA_GROUP = GMMA_N, + K_PER_GMMA_GROUP = Cta_tile::K, + }; + + // The distribution of threads in the output tile. + // TODO + enum { + THREADS_PER_MMA_M = 8, + THREADS_PER_MMA_N = 4, + }; + + // The number of core matrices per GMMA. + enum { + CORES_M_PER_GROUP = 8 * Cta_tile::WARPS_M_PER_GROUP, + CORES_N_PER_GROUP = 8 * Cta_tile::WARPS_N_PER_GROUP, + CORES_M = GMMA_M / CORES_M_PER_GROUP, + CORES_N = GMMA_N / CORES_N_PER_GROUP, + }; + + // The number of logical rows/cols per thread. + enum { + // A thread owns 1 row per core matrix. + ROWS_PER_THREAD = CORES_M, + // A thread owns 2 col per core matrix. + COLS_PER_THREAD = CORES_N * 2, + }; + + static_assert(ROWS_PER_THREAD == 2); + static_assert(COLS_PER_THREAD == GMMA_N / 4); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class Hopper_instructions { + HGMMA_FP16, + HGMMA_BF16, + HGMMA_FP32, + IGMMA_INT32, + QGMMA_E4M3_FP32, + QGMMA_E5M2_FP32, + QGMMA_E4M3_FP16, + QGMMA_E5M2_FP16 +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper HGMMA FP16 Traits +template +struct Hopper_hgmma_fp16_traits + : public Traits { + // The GMMA shape. + enum { GMMA_M = GMMA_M_, GMMA_N = GMMA_N_, GMMA_K = 16 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirements. 
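+  // For example, a 64x128x16 warp-group tile (GMMA_M = 64, GMMA_N = 128, GMMA_K = 16)
+  // satisfies every check below.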
+ static_assert(GMMA_K == 16, "GMMA K must be 16; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The MMA tile. + template + using Mma_tile = Hopper_gmma_tile; + + // The handle to differentiate instructions. + static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = + fmha::Hopper_instructions::HGMMA_FP16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper HGMMA FP32 Traits +template +struct Hopper_hgmma_fp32_traits + : public Traits { + // The GMMA shape. + enum { GMMA_M = GMMA_M_, GMMA_N = GMMA_N_, GMMA_K = 16 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirements. + static_assert(GMMA_K == 16, "GMMA K must be 16; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The MMA tile. + template + using Mma_tile = Hopper_gmma_tile; + + // The handle to differentiate instructions. + static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = + fmha::Hopper_instructions::HGMMA_FP32; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper BF16 HGMMA Traits +template +struct Hopper_hgmma_bf16_traits : public Traits { + // The GMMA shape. + enum { GMMA_M = GMMA_M_, GMMA_N = GMMA_N_, GMMA_K = 16 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirements. + static_assert(GMMA_K == 16, "GMMA K must be 16; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The MMA tile. 
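+  // As with the fp16/fp32 HGMMA traits above, this maps a Cta_tile onto
+  // warp-group-level GMMA tiles of shape GMMA_M x GMMA_N x GMMA_K.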
+ template + using Mma_tile = Hopper_gmma_tile; + + // The handle to differentiate instructions. + static constexpr fmha::Hopper_instructions HOPPER_INSTRUCTION = + fmha::Hopper_instructions::HGMMA_BF16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper IGMMA Traits +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hopper_igmma_int8_int32_traits + : public Traits { + using Base = Traits; + + // The GMMA shape + enum { GMMA_M = GMMA_M_ }; + + enum { GMMA_N = GMMA_N_ }; + + enum { GMMA_K = 32 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirement + static_assert(GMMA_K == 32, "GMMA K must be 32; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The MMA tile. + template + using Mma_tile = Hopper_gmma_tile; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Hopper QGMMA Traits +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hopper_qgmma_fp8_fp32_traits + : public Traits { + using Base = Traits; + + using Input_type_A = Input_type_A_; + using Input_type_B = Input_type_B_; + using Output_type = Output_type_; + + // The GMMA shape + enum { GMMA_M = GMMA_M_ }; + + enum { GMMA_N = GMMA_N_ }; + + enum { GMMA_K = 32 }; + + // is A operand in RF for GMMA? + static constexpr bool GMMA_A_RF = GMMA_A_RF_; + + // is B operand in RF for GMMA? + static constexpr bool GMMA_B_RF = GMMA_B_RF_; + + // GMMA shape has certain requirement + static_assert(GMMA_K == 32, "GMMA K must be 32; this might change"); + static_assert(GMMA_M == 64, "GMMA M must be 64; this might change"); + static_assert(GMMA_N % 8 == 0, "GMMA N must be multiple of 8; this might change"); + static_assert(GMMA_N <= 256, "GMMA N must be no larger than 256; this might change"); + + // GMMA does not allow both operands coming from RF. + static_assert((GMMA_A_RF && GMMA_B_RF) != true, + "GMMA does not allow both operands coming from RF."); + + // The Cta tile. + template + using Cta_tile = Hopper_cta_tile; + + // The Cta tile. + template + using Cta_padded_tile = + Hopper_cta_tile; + + // The CGA Tile + template + using Cga_tile = Hopper_cga_tile; + + // The XMMA tile. + template + using Mma_tile = Hopper_gmma_tile; + + // Used by low precision floating point types (e4m3, e5m2, etc.) 
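+  // The softmax probabilities are scaled by SOFTMAX_FP_QUANT_SCALE before being
+  // stored in the narrow FP8 format and by SOFTMAX_FP_DEQUANT_SCALE (its reciprocal)
+  // when consumed, so the two factors are meant to cancel up to rounding.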
+ static constexpr float SOFTMAX_FP_QUANT_SCALE = Softmax_fp_quant_scale(); + static constexpr float SOFTMAX_FP_DEQUANT_SCALE = 1.f / SOFTMAX_FP_QUANT_SCALE; +}; + +template +using Hopper_qgmma_e4m3_fp32_traits = + Hopper_qgmma_fp8_fp32_traits; + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/utils.h b/csrc/fmha_v2/fmha/utils.h new file mode 100644 index 0000000000..f65d2fe661 --- /dev/null +++ b/csrc/fmha_v2/fmha/utils.h @@ -0,0 +1,2355 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include +#include + +#if defined(__CLANGD__) +#include <__clang_cuda_builtin_vars.h> +#include <__clang_cuda_math.h> +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +// include warpgroup related instructions, used by SM90. +#include +// include gmma related instructions, used by SM90. +#include +// include tma related instructions, used by SM90. +#include + +#include "fmha/numeric_types.h" + +#define FP32_I2F_MAGIC_NUMBER 12582912.f +#define FP32_I2F_MAGIC_NUMBER_HEX 0x4b400000 + +extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace introspection { + +template +struct Unpack; + +template +struct Unpack { + // if we simply static_assert(false) then compiler will not emit template params upon failure + static_assert(N < INT_MIN, ""); + using Type = std::integral_constant; +}; + +template +struct Unpack { + using Type = Unpack; + using Unpack_first = typename Unpack::Type; + using Unpack_remaining = typename Unpack::Type; +}; + +} // namespace introspection + +// Example usage: +// +// Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER> foo; +// +// or +// +// Inspect_ns<(int)USE_LDGSTS_, PRED_REGS, (int)IS_HOPPER>{}.foo(); +// +// Output by nvcc: +// +// ./src/fmha/gmem_tile_qkv_packed.h(70): error: static assertion failed with "" +// detected during: +// instantiation of class "fmha::v2::Unpack [with N=1]" +// (77): here +// instantiation of class "fmha::v2::Unpack [with N=1, Ns=<2, 0>]" +// (84): here +// instantiation of class "fmha::v2::Inspect_ns [with Ns=<1, 2, 0>]" +// (143): here +template +struct Inspect_ns { + using Type = typename introspection::Unpack::Type; +}; + +// Can be used alongside with static_assert() to figure out the conditions when assertion failed +// Example: +// +// Cond_inspect_ns< (int)ROWS >= (int)ROWS_PER_LDG, ROWS, ROWS_PER_LDG> foo; +// +// Output by nvcc (when condition is not met): +// +// ./src/fmha/utils.h(163): error: static assertion failed with "" +// detected during: +// instantiation of class "Cond_inspect_ns [with COND=false, Ns=<32, +// 64>]" +template +struct Cond_inspect_ns { + static_assert(COND, ""); +}; + +// Example: +// +// Inspect_type{}.foo(); +// +// or +// +// Inspect_type foo; +// +// Output by nvcc: +// +// ./src/fmha/utils.h(189): error: class "fmha::Ampere_hmma_tile, 16>" has no member "Dummy" +// 
detected during: +// instantiation of class "Inspect_type [with +// T=fmha::Ampere_hmma_tile, 16>]" +template +struct Inspect_type { + // Purposefully trigger error by referencing non-existent T::Dummy + using Dummy = typename T::Dummy; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Row { + static constexpr bool COL = false; + static constexpr bool ROW = true; +}; + +struct Col { + static constexpr bool COL = true; + static constexpr bool ROW = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Round_up { + enum { VALUE = (M + N - 1) / N * N }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Tile_nhw { + enum { N = N_, H = H_, W = W_ }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Next_power_of_two {}; + +template +struct Next_power_of_two { + enum { VALUE = M }; +}; + +template <> +struct Next_power_of_two<3, false> { + enum { VALUE = 4 }; +}; + +template <> +struct Next_power_of_two<5, false> { + enum { VALUE = 8 }; +}; + +template <> +struct Next_power_of_two<6, false> { + enum { VALUE = 8 }; +}; + +template <> +struct Next_power_of_two<7, false> { + enum { VALUE = 8 }; +}; + +template <> +struct Next_power_of_two<9, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<10, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<11, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<12, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<13, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<14, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<15, false> { + enum { VALUE = 16 }; +}; + +template <> +struct Next_power_of_two<24, false> { + enum { VALUE = 32 }; +}; + +template <> +struct Next_power_of_two<40, false> { + enum { VALUE = 64 }; +}; + +template <> +struct Next_power_of_two<48, false> { + enum { VALUE = 64 }; +}; + +template <> +struct Next_power_of_two<72, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<80, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<96, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<104, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<112, false> { + enum { VALUE = 128 }; +}; + +template <> +struct Next_power_of_two<144, false> { + enum { VALUE = 256 }; +}; + +template <> +struct Next_power_of_two<160, false> { + enum { VALUE = 256 }; +}; + +template <> +struct Next_power_of_two<192, false> { + enum { VALUE = 256 }; +}; + +template <> +struct Next_power_of_two<576, false> { + enum { VALUE = 1024 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Prev_power_of_two {}; + +template +struct Prev_power_of_two { + enum { VALUE = N }; +}; + +template <> +struct Prev_power_of_two<3, false> { + enum { VALUE = 2 }; +}; + +template <> +struct Prev_power_of_two<5, false> { + enum { VALUE = 4 }; +}; + +template <> +struct Prev_power_of_two<6, false> { + enum { VALUE = 4 }; +}; + +template <> 
+struct Prev_power_of_two<7, false> { + enum { VALUE = 4 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_skew { + // The size of a transaction. + enum { BYTES_PER_TRX = 128 }; + + // The remainder of the row without skew. + enum { REMAINDER = BYTES_PER_ROW % BYTES_PER_TRX }; + + // The value. + enum { VALUE = REMAINDER <= SKEW ? SKEW - REMAINDER : BYTES_PER_TRX + SKEW - REMAINDER }; + + // Make sure the math works ;) + static_assert((BYTES_PER_ROW + VALUE) % BYTES_PER_TRX == SKEW, ""); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_skew { + // No skew! + enum { VALUE = 0 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Div_up { + enum { VALUE = (M + N - 1) / N }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Max { + enum { VALUE = A >= B ? A : B }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Max_3 { + enum { VALUE = Max::VALUE, C>::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Min { + enum { VALUE = A <= B ? A : B }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Uint_from_size_in_bytes {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<1> { + using Type = uint8_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<2> { + using Type = uint16_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<4> { + using Type = uint32_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<8> { + using Type = uint2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Uint_from_size_in_bytes<16> { + using Type = uint4; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Warp_masks {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Warp_masks<8, 1, 1> { + enum { M = 0xe0, N = 0x00, K = 0x00 }; +}; + +template <> +struct Warp_masks<4, 2, 1> { + enum { M = 0x60, N = 0x80, K = 0x00 }; +}; + +template <> +struct Warp_masks<4, 1, 2> { + enum { M = 0x60, N = 0x00, K = 0x80 }; +}; + +template <> +struct Warp_masks<4, 1, 1> { + enum { M = 0x60, N = 0x00, K = 0x00 }; +}; + +template <> +struct Warp_masks<2, 4, 1> { + enum { M = 0x20, N = 0xc0, K = 0x00 }; +}; + +template <> +struct Warp_masks<2, 2, 2> { + enum { M = 0x20, N = 0x40, K = 0x80 }; +}; + +template <> +struct Warp_masks<2, 2, 1> { + enum { M = 0x20, N = 0x40, K = 0x00 }; +}; + +template <> +struct Warp_masks<2, 1, 2> { + enum { M = 0x20, N = 0x00, K = 0x40 }; +}; + +template <> +struct Warp_masks<2, 1, 1> { + enum { M = 0x20, N 
= 0x00, K = 0x00 }; +}; + +template <> +struct Warp_masks<1, 8, 1> { + enum { M = 0x00, N = 0xe0, K = 0x00 }; +}; + +template <> +struct Warp_masks<1, 4, 2> { + enum { M = 0x00, N = 0x60, K = 0x80 }; +}; + +template <> +struct Warp_masks<1, 4, 1> { + enum { M = 0x00, N = 0x60, K = 0x00 }; +}; + +template <> +struct Warp_masks<1, 2, 2> { + enum { M = 0x00, N = 0x20, K = 0x40 }; +}; + +template <> +struct Warp_masks<1, 2, 1> { + enum { M = 0x00, N = 0x20, K = 0x00 }; +}; + +template <> +struct Warp_masks<1, 1, 4> { + enum { M = 0x00, N = 0x00, K = 0x60 }; +}; + +template <> +struct Warp_masks<1, 1, 2> { + enum { M = 0x00, N = 0x00, K = 0x20 }; +}; + +template <> +struct Warp_masks<1, 1, 1> { + enum { M = 0x00, N = 0x00, K = 0x00 }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int clz(int x) { + for (int i = 31; i >= 0; --i) { + if ((1 << i) & x) { + return 31 - i; + } + } + return 32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int find_log_2(int x, bool round_up = false) { + int a = 31 - clz(x); + if (round_up) { + a += (x & (x - 1)) ? 1 : 0; + } + return a; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void find_divisor(uint32_t& mul, uint32_t& shr, int x) { + assert(x != 0); + if (x == 1) { + // If dividing by 1, reduced math doesn't work because mul_coeff would need to be 2^32, + // which doesn't fit into unsigned int. the div() routine handles this special case + // separately. + mul = 0; + shr = 0; + } else { + // To express the division N/D in terms of a multiplication, what we first + // imagine is simply N*(1/D). However, 1/D will always evaluate to 0 (for D>1), + // so we need another way. There's nothing that says we have to use exactly + // the fraction 1/D; instead it could be any X/Y that reduces to 1/D (i.e., + // Y=X*D), or at least to "close enough" to it. If we pick Y that is a power + // of two, then the N*(X/Y) can be N*X followed by a right-shift by some amount. + // The power of two we should pick should be at least 2^32, because in the + // div() routine we'll use umulhi(), which returns only the upper 32 bits -- + // this being equivalent to a right-shift by 32. But we might want a higher + // power of two for better accuracy depending on the magnitude of the denominator. + // Once we've picked Y, then X [our mul_coeff value] is simply Y/D, rounding up, + // and we save shift_coeff as whatever further shift we have to do beyond + // what the umulhi() implies. 
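+    // Worked example (illustrative): for x = 3, find_log_2(3, true) = 2, so p = 33,
+    // mul = ceil(2^33 / 3) = 0xAAAAAAAB and shr = 1; fast_divmod() then recovers
+    // n / 3 as __umulhi(n, mul) >> 1 for any 32-bit n.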
+ uint32_t p = 31 + find_log_2(x, true); + uint32_t m = (uint32_t)(((1ull << p) + (uint32_t)x - 1) / (uint32_t)x); + + mul = m; + shr = p - 32; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void fast_divmod(int& div, int& mod, int x, int y, uint32_t mul, uint32_t shr) { + if (y == 1) { + div = x; + mod = 0; + } else { + div = __umulhi((uint32_t)x, mul) >> shr; + mod = x - div * y; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t bfadd2(uint32_t a, uint32_t b) { + uint32_t c; + uint32_t one = 0x3f803f80; + ; + asm volatile("fma.rn.bf16x2 %0, %1, %3, %2;\n" : "=r"(c) : "r"(a), "r"(b), "r"(one)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmax2(uint32_t a, uint32_t b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela, selb;\n" + "\n" + "\t set.ge.f16x2.f16x2 sela, %1, %2;\n" + "\t set.gt.f16x2.f16x2 selb, %2, %1;\n" + "\n" + "\t mul.f16x2 %0, sela, %1;\n" + "\t fma.rn.f16x2 %0, selb, %2, %0;\n" + "}\n" + : "=r"(c) + : "r"(a), "r"(b)); +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hmax4(uint2 a, uint2 b) { + uint2 c; + c.x = hmax2(a.x, b.x); + c.y = hmax2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmax8(uint4 a, uint4 b) { + uint4 c; + c.x = hmax2(a.x, b.x); + c.y = hmax2(a.y, b.y); + c.z = hmax2(a.z, b.z); + c.w = hmax2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela, selb;\n" + "\n" + "\t set.le.f16x2.f16x2 sela, %1, %2;\n" + "\t set.lt.f16x2.f16x2 selb, %2, %1;\n" + "\n" + "\t mul.f16x2 %0, sela, %1;\n" + "\t fma.rn.f16x2 %0, selb, %2, %0;\n" + "}\n" + : "=r"(c) + : "r"(a), "r"(b)); +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t bfmul2(uint32_t a, uint32_t b) { + uint32_t c; + asm("{.reg .b32 c;\n" + " mov.b32 c, 0x80008000U;\n" + " fma.rn.bf16x2 %0,%1,%2,c;}\n" + : "=r"(c) + : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 
hmul4(uint2 a, uint2 b) { + uint2 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmul8(uint4 a, uint4 b) { + uint4 c; + c.x = hmul2(a.x, b.x); + c.y = hmul2(a.y, b.y); + c.z = hmul2(a.z, b.z); + c.w = hmul2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { + uint4 c; + c.x = hmul2(a, b.x); + c.y = hmul2(a, b.y); + c.z = hmul2(a, b.z); + c.w = hmul2(a, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint32_t mul2(uint32_t a, uint32_t b) { + return hmul2(a, b); +} + +template <> +inline __device__ uint32_t mul2(uint32_t a, uint32_t b) { + return bfmul2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint4 mul8(uint32_t a, uint4 b) { + uint4 c; + c.x = hmul2(a, b.x); + c.y = hmul2(a, b.y); + c.z = hmul2(a, b.z); + c.w = hmul2(a, b.w); + return c; +} + +template <> +inline __device__ uint4 mul8(uint32_t a, uint4 b) { + uint4 c; + c.x = bfmul2(a, b.x); + c.y = bfmul2(a, b.y); + c.z = bfmul2(a, b.z); + c.w = bfmul2(a, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hrelu2(uint32_t x) { + uint32_t res; + uint32_t const zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#endif + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t bfrelu2(uint32_t x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + uint32_t res; + uint32_t const zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +#endif + // not implemented yet + return x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint32_t relu2(uint32_t x) { + return hrelu2(x); +} + +template <> +inline __device__ uint32_t relu2(uint32_t x) { + return bfrelu2(x); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t habs2(uint32_t x) { + uint32_t res; + asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// static inline __device__ uint32_t add_bias(uint32_t a, uint32_t bias, bool relu) { +// uint32_t c; +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// if( relu ) { +// uint32_t one = 0x3c003c00u; +// asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(c) : "r"(a), "r"(one), +// "r"(bias)); +// } else { +// c = hadd2(a, bias); +// } +// #else +// c = 
hadd2(a, bias); +// if( relu ) { +// c = hrelu2(c); +// } +// #endif +// return c; +// } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// static inline __device__ uint2 add_bias(uint2 a, uint2 bias, bool relu) { +// uint2 dst; +// dst.x = add_bias(a.x, bias.x, relu); +// dst.y = add_bias(a.y, bias.y, relu); +// return dst; +// } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// static inline __device__ uint4 add_bias(uint4 a, uint4 bias, bool relu) { +// uint4 dst; +// dst.x = add_bias(a.x, bias.x, relu); +// dst.y = add_bias(a.y, bias.y, relu); +// dst.z = add_bias(a.z, bias.z, relu); +// dst.w = add_bias(a.w, bias.w, relu); +// return dst; +// } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clamp float +inf/-inf +static inline __device__ float satfinite(float x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860 + // bit representation of maximum value of float + uint32_t clamp_value = 0x7f7fffffu; + asm volatile("min.xorsign.abs.f32 %0, %0, %1;" : "+f"(x) : "r"(clamp_value)); + return x; +#else + // bit representation of maximum and minimum value of float + uint32_t umax = 0x7f7fffffu; + uint32_t umin = 0xff7fffffu; + float out; + asm volatile("min.f32 %0, %1, %2;" : "=f"(out) : "f"(x), "r"(umax)); + asm volatile("max.f32 %0, %0, %1;" : "+f"(out) : "r"(umin)); + return out; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clamp half2 +inf/-inf +static inline __device__ uint32_t satfinite_h2(uint32_t h2) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 860 + uint32_t out, clamp_value; + clamp_value = 0x7bff7bffu; + asm volatile("min.xorsign.abs.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(clamp_value)); + return out; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800 + // bit representation of maximum and minimum value of half2 + uint32_t umax = 0x7bff7bffu; + uint32_t umin = 0xfbfffbffu; + uint32_t out; + asm volatile("min.f16x2 %0, %1, %2;" : "=r"(out) : "r"(h2), "r"(umax)); + asm volatile("max.f16x2 %0, %0, %1;" : "+r"(out) : "r"(umin)); + return out; +#else + // Take the absolute value of h2. It should map to |Rx| in SASS. + uint32_t p2; + asm volatile("abs.f16x2 %0, %1;" : "=r"(p2) : "r"(h2)); + + // Compute a mask for each fp16: 0xffff if +INF and 0x0000 otherwise. + uint32_t inf2 = 0x7c007c00u; + uint32_t mask; + asm volatile("set.eq.u32.f16x2 %0, %1, %2;" : "=r"(mask) : "r"(p2), "r"(inf2)); + + // Recreate the new value. 0x7bff is the max value for FP16. + p2 = (~mask & p2) | (mask & 0x7bff7bff); + + // Simply re-add the sign and we're done. + return p2 | (h2 & 0x80008000); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +static inline __device__ T clamp(T x, T lb, T ub) { + return x < lb ? lb : (x > ub ? 
ub : x); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float custom_exp2f(float x, float scale, float scaled_max) { + float d1, d2; + asm("fma.rz.ftz.f32 %0, %1, %2, %3;" : "=f"(d1) : "f"(x), "f"(scale), "f"(-scaled_max)); + asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(d2) : "f"(d1)); + return d2; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t clamp_to_zero(uint16_t x) { + uint16_t mask; + asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); + return mask & x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t float_to_half(float f) { + uint16_t h; + asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); + return h; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ bf16_t float_to_bf16(float f) { return __float2bfloat16(f); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_half2(float a, float b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); +#else + uint16_t lo = float_to_half(a); + uint16_t hi = float_to_half(b); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_bf16_x2(float a, float b) { + uint32_t c; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); +#else + uint16_t* px = reinterpret_cast(&a); + uint16_t* py = reinterpret_cast(&b); + uint16_t value = px[1]; + uint16_t value2 = py[1]; + + if (px[0] == 0x8000) { + if ((value & 0x1) == 1) value++; + } else if (px[0] > 0x8000) { + value++; + } + + if (py[0] == 0x8000) { + if ((value2 & 0x1) == 1) value2++; + } else if (py[0] > 0x8000) { + value2++; + } + + uint32_t high = reinterpret_cast(value2); + c = (high << 16) | value; +#endif + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint32_t float2_to_16bit_2(float a, float b) { + return float2_to_half2(a, b); +} + +template <> +inline __device__ uint32_t float2_to_16bit_2(float a, float b) { + return float2_to_bf16_x2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a, a); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float2_to_half2(float2 const& f) { + return float2_to_half2(f.x, f.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t float_to_bf16_2(float a) { return float2_to_bf16_x2(a, a); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 float4_to_half4(float x, float y, float z, 
float w) { + uint2 d; + d.x = float2_to_half2(x, y); + d.y = float2_to_half2(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ uint2 float4_to_16bit_x4(float x, float y, float z, float w) { + uint2 d; + d.x = float2_to_half2(x, y); + d.y = float2_to_half2(z, w); + return d; +} + +template <> +inline __device__ uint2 float4_to_16bit_x4(float x, float y, float z, float w) { + uint2 d; + d.x = float2_to_bf16_x2(x, y); + d.y = float2_to_bf16_x2(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); +#else + d = hrelu2(hfma2(a, b, c)); +#endif + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t h0_h0(uint32_t x) { + uint32_t y; + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" + : "=r"(y) + : "r"(x)); + return y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float h0_to_float(uint32_t h2) { + float f; + asm volatile( + "{\n" + ".reg .f16 lo, hi;\n" + "mov.b32 {lo, hi}, %1;\n" + "cvt.f32.f16 %0, lo;\n" + "}\n" + : "=f"(f) + : "r"(h2)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t h1_h1(uint32_t x) { + uint32_t y; + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" + : "=r"(y) + : "r"(x)); + return y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { + uint16_t d; + asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { return hadd2(a, b); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hadd4(uint2 a, uint2 b) { + uint2 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint2 hadd(uint2 a, uint2 b) { return hadd4(a, b); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hadd8(uint4 a, uint4 b) { + uint4 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + c.z = hadd2(a.z, b.z); + c.w = hadd2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template 
function to support both half and bfloat16 +template +inline __device__ uint4 add8(uint4 a, uint4 b) { + return hadd8(a, b); +} + +template <> +inline __device__ uint4 add8(uint4 a, uint4 b) { + uint4 c; + c.x = bfadd2(a.x, b.x); + c.y = bfadd2(a.y, b.y); + c.z = bfadd2(a.z, b.z); + c.w = bfadd2(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 fadd4(uint4 a, uint4 b) { + float4 c; + c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); + c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); + c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); + c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); + return reinterpret_cast(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint4 hadd(uint4 a, uint4 b) { return hadd8(a, b); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float bf16_to_float(uint16_t h) { + float f; + asm volatile("mov.b32 %0, {0, %1};\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float2 half2_to_float2(uint32_t x) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float2 bf16_2_to_float2(uint32_t x) { + float2 res; + asm volatile( + "{\n" + " .reg .b16 lo, hi;\n" + " mov.b32 {lo, hi}, %2;\n" + " mov.b32 %0, {0, lo};\n" + " mov.b32 %1, {0, hi};\n" + "}\n" + : "=f"(res.x), "=f"(res.y) + : "r"(x)); + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Template function to support both half and bfloat16 +template +inline __device__ float2 convert_from_16bit_2(uint32_t x) { + return half2_to_float2(x); +} + +template <> +inline __device__ float2 convert_from_16bit_2(uint32_t x) { + return bf16_2_to_float2(x); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void half2_to_float2(float& x, float& y, uint32_t h) { + float2 tmp = half2_to_float2(h); + x = tmp.x; + y = tmp.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { + uint16_t d; + asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { + uint16_t d; + asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two half2's or bf162's into float, then take their dot product. 
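+// fma2_in_float unpacks each 32-bit register into a float2 and returns
+// af.x * bf.x + af.y * bf.y; fma8_in_float below simply accumulates four such
+// partial dot products for a 16-byte uint4 register.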
+template +inline __device__ float fma2_in_float(uint32_t const a, uint32_t const b) { + float2 af = fmha::convert_from_16bit_2(a); + float2 bf = fmha::convert_from_16bit_2(b); + return af.x * bf.x + af.y * bf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float fma8_in_float(uint4 const a, uint4 const b) { + float sum; + sum = fmha::fma2_in_float(a.x, b.x); + sum += fmha::fma2_in_float(a.y, b.y); + sum += fmha::fma2_in_float(a.z, b.z); + sum += fmha::fma2_in_float(a.w, b.w); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint16_t& dst) { dst = uint16_t(0); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint32_t& dst) { dst = 0u; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint2& dst) { dst = make_uint2(0u, 0u); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void clear(uint4& dst) { dst = make_uint4(0u, 0u, 0u, 0u); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// P R E D I C A T E P A C K I N G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Compute_number_of_pred_regs { + enum { VALUE = Div_up::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void pack_predicates(uint32_t (&preds)[M], uint32_t const (&p)[N]) { + // Make sure the values match. + static_assert(Compute_number_of_pred_regs::VALUE == M, ""); + + // The number of complete steps (where we use all the predicates in a byte). + enum { COMPLETE_BYTES = N / PREDS_PER_BYTE }; + + // Make sure we allocated enough predicate registers. + static_assert(Div_up::VALUE <= M, ""); + + // The remainder. + enum { REMAINDER = N - COMPLETE_BYTES * PREDS_PER_BYTE }; + + // Make sure we got the math right and the remainder is between 0 and 3. + static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); + + // The mask to extract the predicates. + enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; + + // Run complete steps. +#pragma unroll + for (int ii = 0; ii < M; ++ii) { + // The number of complete bytes for that register. Be careful it can be > than 4 ;) + int const COMPLETE = (N - ii * PREDS_PER_REG) / PREDS_PER_BYTE; + + // Pack the predicates in a register. + uint32_t reg = 0u; +#pragma unroll + for (int jj = 0; jj < 4; ++jj) { + // Early exit. + if (jj >= COMPLETE) { + break; + } + + // Prepare the array of predicates. 
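+      // (Layout reminder: predicates are packed 4 per byte, 16 per 32-bit register,
+      // with predicate kk of byte jj stored at bit jj * 8 + kk.)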
+ bool tmp[PREDS_PER_BYTE]; +#pragma unroll + for (int kk = 0; kk < PREDS_PER_BYTE; ++kk) { + tmp[kk] = p[ii * PREDS_PER_REG + jj * PREDS_PER_BYTE + kk] != 0; + } + + // Store the predicates. +#pragma unroll + for (int kk = 0; kk < PREDS_PER_BYTE; ++kk) { + if (tmp[kk]) { + reg |= 1u << (jj * 8 + kk); + } + } + } + + // Skip the rest of the code if we do not have a remainder. + if (COMPLETE < 4 && REMAINDER > 0) { + // The mask to extract the predicates. + enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; + + // Prepare the array of predicates. + bool tmp[PREDS_PER_BYTE]; +#pragma unroll + for (int jj = 0; jj < REMAINDER; ++jj) { + tmp[jj] = p[COMPLETE_BYTES * PREDS_PER_BYTE + jj] != 0; + } + + // Store the predicates. +#pragma unroll + for (int jj = 0; jj < REMAINDER; ++jj) { + if (tmp[jj]) { + reg |= 1u << (COMPLETE * 8 + jj); + } + } + } + + // Store the predicate register. + preds[ii] = reg; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t pack_predicates(uint32_t const (&p)[N]) { + uint32_t tmp[1]; + pack_predicates(tmp, p); + return tmp[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// G E N E R I C P R E D I C A T E D L D G S T S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldgsts_(Functor& fct, uint32_t const (&preds)[M]) { + // The number of complete bytes (where we use all the predicates in a byte). + enum { COMPLETE = N / PREDS_PER_BYTE }; + + // Make sure we did allocate enough predicates. + static_assert(Div_up::VALUE <= M, ""); + + // The remainder. + enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; + + // Make sure we got the math right and the remainder is between 0 and 3. + static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); + + // The mask to extract the predicates. + enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; + +// Clear the fetch registers. +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + fct.clear(ii); + } + + // Run complete steps. + bool p[PREDS_PER_BYTE]; +#pragma unroll + for (int ii = 0; ii < COMPLETE; ++ii) { + // The predicate. + uint32_t reg = preds[ii / BYTES_PER_REG]; + + // Extract the predicates. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } + +// Issue the loads. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + fct.ldgsts(ii * PREDS_PER_BYTE + jj, p[jj]); + } + } + + // Skip the rest of the code if we do not have a remainder. + if (REMAINDER > 0) { + // The mask to extract the predicates. + enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; + + // The predicate register. + uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; + + // Extract the predicates. +#pragma unroll + for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) { + uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); + p[jj] = (reg & mask) != 0u; + } + +// Issue the loads. 
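+    // Only the first REMAINDER of the four bits extracted above are meaningful: they sit in
+    // byte (COMPLETE % BYTES_PER_REG) of the last predicate register, which is exactly where
+    // pack_predicates() put the trailing predicates.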
+#pragma unroll + for (int ii = 0; ii < REMAINDER; ++ii) { + fct.ldgsts(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldgsts_(Functor& fct, uint32_t preds) { + uint32_t tmp[1] = {preds}; + ldgsts_(fct, tmp); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint8_t& dst, void const* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint16_t& dst, void const* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint32_t& dst, void const* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint2& dst, void const* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint4& dst, void const* ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ldg_functor { + // Ctor. + inline __device__ Ldg_functor(Data_type (&fetch)[N], void const* (&ptrs)[N]) + : fetch_(fetch), ptrs_(ptrs) {} + + // Clear the element. + inline __device__ void clear(int ii) { fmha::clear(fetch_[ii]); } + + // Trigger the loads. + inline __device__ void ldgsts(int ii, bool p) { + if (p) { + ldg(fetch_[ii], ptrs_[ii]); + } + } + + // The fetch registers. + Data_type (&fetch_)[N]; + // The pointers. 
+ void const* (&ptrs_)[N]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg_(Data_type (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + Ldg_functor fct(fetch, ptrs); + ldgsts_(fct, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint8_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint16_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint32_t (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint2 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldg(uint4 (&fetch)[N], void const* (&ptrs)[N], uint32_t (&preds)[M]) { + ldg_(fetch, ptrs, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldgdepbar() { + if (USE_LDGSTS) { + asm volatile("cp.async.commit_group;\n" ::); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void depbar_() { + if (USE_LDGSTS) { + asm volatile("cp.async.wait_group %0;\n" ::"n"(COUNT)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void depbar() { + if (USE_LDGSTS) { + int const VALUE = Max::VALUE; + asm volatile("cp.async.wait_group %0;\n" ::"n"(VALUE)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldgsts128(uint32_t dst, void const* src, bool p = true) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + uint32_t m = p ? 16u : 0u; + asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(dst), "l"(src), "r"(m)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Ldgsts_functor { + // Ctor. + inline __device__ Ldgsts_functor(uint32_t (&smem_ptrs)[N], void const* (&gmem_ptrs)[N]) + : smem_ptrs_(smem_ptrs), gmem_ptrs_(gmem_ptrs) {} + + // Does nothing. + inline __device__ void clear(int ii) {} + + // Trigger the load-store instruction. + inline __device__ void ldgsts(int ii, bool p) { ldgsts128(smem_ptrs_[ii], gmem_ptrs_[ii], p); } + + // The shared memory pointers. + uint32_t (&smem_ptrs_)[N]; + // The global memory pointers. 
+ void const* (&gmem_ptrs_)[N]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void ldgsts(uint32_t (&dst)[N], void const* (&src)[N], uint32_t (&preds)[M]) { + Ldgsts_functor fct(dst, src); + ldgsts_(fct, preds); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint16_t& dst, uint32_t ptr) { + asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint32_t& dst, uint32_t ptr) { + asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint2& dst, uint32_t ptr) { + asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void lds(uint4& dst, uint32_t ptr) { + asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D S M +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint32_t& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint32_t& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint2& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint2& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst.x), "=r"(dst.y) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm(uint4& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsmt(uint4& dst, uint32_t ptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 + asm 
volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) + : "r"(ptr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T S M +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsm(uint32_t ptr, uint32_t const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsmt(uint32_t ptr, uint32_t const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n" ::"r"(ptr), "r"(src)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsm(uint32_t ptr, uint2 const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), "r"(src.x), + "r"(src.y)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsmt(uint32_t ptr, uint2 const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n" ::"r"(ptr), + "r"(src.x), "r"(src.y)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsm(uint32_t ptr, uint4 const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), + "r"(src.x), "r"(src.y), "r"(src.z), "r"(src.w)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stsmt(uint32_t ptr, uint4 const& src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile( + "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(ptr), + "r"(src.x), "r"(src.y), "r"(src.z), "r"(src.w)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, float val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint8_t val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint16_t val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint32_t val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint2 val) { *reinterpret_cast(ptr) = val; } + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void* ptr, uint4 val) { *reinterpret_cast(ptr) = val; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T S +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint16_t val) { + asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint32_t val) { + asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint2 val) { + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" : : "r"(ptr), "r"(val.x), "r"(val.y)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void sts(uint32_t ptr, uint4 val) { + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts_(uint32_t (&ptrs)[N], Data_type const (&data)[N]) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + sts(ptrs[ii], data[ii]); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], uint16_t const (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], uint32_t const (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], uint2 const (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void sts(uint32_t (&ptrs)[N], uint4 const (&data)[N]) { + sts_(ptrs, data); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) +#define __HALF2_TO_CUI(var) *(reinterpret_cast(&(var))) + +static __device__ __inline__ void atomicAdd_half2(half2* const address, const half2 val) { + asm volatile("{ red.global.add.noftz.f16x2 [%0],%1; }\n" ::"l"(address), "r"(__HALF2_TO_CUI(val)) + : "memory"); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w) { +#if defined(USE_F2I_EMULATION_TRICK) + // Make sure the float is in the proper range. + float cx, cy, cz, cw; + if (CAN_BE_NEGATIVE) { + cx = fmha::clamp(x, -128.f, 127.f); + cy = fmha::clamp(y, -128.f, 127.f); + cz = fmha::clamp(z, -128.f, 127.f); + cw = fmha::clamp(w, -128.f, 127.f); + } else { + cx = fminf(x, 127.f); + cy = fminf(y, 127.f); + cz = fminf(z, 127.f); + cw = fminf(w, 127.f); + } + + // Re-add the magic number. 
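+  // (This is the classic float-to-int-via-mantissa trick: after clamping, adding the large
+  // FP32_I2F_MAGIC_NUMBER constant -- typically built around 2^23 -- leaves the rounded integer
+  // in the low mantissa bits of each float, so the prmt byte-permutes below can assemble the
+  // four int8 results without any cvt instructions. The constant itself is defined elsewhere.)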
+ cx += FP32_I2F_MAGIC_NUMBER; + cy += FP32_I2F_MAGIC_NUMBER; + cz += FP32_I2F_MAGIC_NUMBER; + cw += FP32_I2F_MAGIC_NUMBER; + + // We need unsigned ints... + uint32_t a = reinterpret_cast(cx); + uint32_t b = reinterpret_cast(cy); + uint32_t c = reinterpret_cast(cz); + uint32_t d = reinterpret_cast(cw); + + // Pack the numbers. + uint32_t dst; + asm volatile("prmt.b32 %0, %1, %2, 0x0040;\n" : "=r"(dst) : "r"(a), "r"(b)); + asm volatile("prmt.b32 %0, %0, %1, 0x0410;\n" : "+r"(dst) : "r"(c)); + asm volatile("prmt.b32 %0, %0, %1, 0x4210;\n" : "+r"(dst) : "r"(d)); + return dst; +#else + uint32_t a; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); + uint32_t b; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y)); + uint32_t c; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z)); + uint32_t d; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + + uint32_t dst; + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c)); + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a)); + return dst; +#endif // defined(USE_F2I_EMULATION_TRICK) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void swizzle_rows(uint32_t& a, uint32_t& b, uint32_t c, uint32_t d) { + asm volatile("prmt.b32 %0, %1, %2, 0x6420;\n" : "=r"(a) : "r"(c), "r"(d)); + asm volatile("prmt.b32 %0, %1, %2, 0x7531;\n" : "=r"(b) : "r"(c), "r"(d)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldsm_with_lds(uint2& data, uint32_t smem) { + int lane = threadIdx.x % 32; + data = {0, 0}; + uint4 v = {0, 0, 0, 0}; + uint32_t* a = reinterpret_cast(&v); + if (lane < 16) { + fmha::lds(v, smem); + } + int src_row = lane / 4; + int src_col = lane % 4; + for (int it = 0; it < 4; it++) { + uint32_t val = a[it]; + uint32_t x = __shfl_sync(uint32_t(-1), val, src_row); + __syncwarp(); + uint32_t y = __shfl_sync(uint32_t(-1), val, src_row + 8); + __syncwarp(); + if (it == src_col) { + data.x = x; + data.y = y; + } + } +} + +inline __device__ void ldsmt_with_lds(uint2& data, uint32_t smem) { + int lane = threadIdx.x % 32; + + uint4 tmp16{0, 0, 0, 0}; // 16B + + if (lane < 16) { + fmha::lds(tmp16, smem); + } + + uint16_t* tmp16c = reinterpret_cast(&tmp16); // 8x2B: we move pairs + + uint16_t* t = reinterpret_cast(&data); // 4x2B + + int const src_col = lane / 4; // 0 - 7 + int const src_row = (lane % 4) * 2; + +// we have to shuffle the values to distribute them in the warp +#pragma unroll + for (int it = 0; it < 8; it++) { + uint16_t val, x, y; + val = tmp16c[it]; + x = __shfl_sync(uint32_t(-1), val, src_row + 0); + __syncwarp(); + y = __shfl_sync(uint32_t(-1), val, src_row + 1); + __syncwarp(); + + if (src_col == it) { + t[0] = x; + t[1] = y; + } + val = tmp16c[it]; + x = __shfl_sync(uint32_t(-1), val, src_row + 8); + __syncwarp(); + y = __shfl_sync(uint32_t(-1), val, src_row + 9); + __syncwarp(); + + if (src_col == it) { + t[2] = x; + t[3] = y; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ inline T operator()(T const& x, T const& y) { return x > y ? 
x : y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator& op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator& op) { + float tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_reduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator& op) { +#pragma unroll + for (int mi = 0; mi < M; mi++) { + dst[mi] = src[mi]; + dst[mi] = Allreduce<4>::run(dst[mi], op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator& op) { + float tmp[M]; +#pragma unroll + for (int mi = 0; mi < M; mi++) { + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_allreduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t elect_one_sync() { + uint32_t pred = 0; +#if __CUDA_ARCH__ >= 900 +#if !defined(__CUDACC_RTC__) + uint32_t laneid = 0; + asm volatile( + "\n\ + {\n\ + .reg .b32 %rx;\n\ + .reg .pred %px;\n\ + elect.one.sync %rx|%px, %2;\n\ + @%px mov.s32 %1, 1;\n\ + mov.s32 %0, %rx;\n\ + }\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); +#else + pred = threadIdx.x == 0; +#endif +#endif + return pred; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t float2_to_e4m3x2(float x, float y) { +#if defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) + uint16_t res; + asm volatile("cvt.rn.e4m3x2.f32.satfinite %0, %2, %1;" : "=h"(res) : "f"(x), "f"(y)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t float4_to_e4m3x4(float x, float y, float z, float w) { +#if 
defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) + uint32_t res; + asm volatile( + "{\n" + ".reg .b16 lo;\n" + ".reg .b16 hi;\n" + "cvt.rn.e4m3x2.f32.satfinite lo, %2, %1;\n" + "cvt.rn.e4m3x2.f32.satfinite hi, %4, %3;\n" + "mov.b32 %0, {lo, hi};\n" + "}" + : "=r"(res) + : "f"(x), "f"(y), "f"(z), "f"(w)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t float4_to_e5m2x4(float x, float y, float z, float w) { +#if defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 890 && defined(FMHA_ENABLE_SM89_QMMA)) || (__CUDA_ARCH__ >= 900)) + uint32_t res; + asm volatile( + "{\n" + ".reg .b16 lo;\n" + ".reg .b16 hi;\n" + "cvt.rn.e5m2x2.f32.satfinite lo, %2, %1;\n" + "cvt.rn.e5m2x2.f32.satfinite hi, %4, %3;\n" + "mov.b32 %0, {lo, hi};\n" + "}" + : "=r"(res) + : "f"(x), "f"(y), "f"(z), "f"(w)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t half4_to_e4m3x4(uint32_t const h2_0, uint32_t const h2_1) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) + uint32_t res; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "cvt.satfinite.rn.e4m3x2.f16x2 lo, %1;\n" + "cvt.satfinite.rn.e4m3x2.f16x2 hi, %2;\n" + "mov.b32 %0, {lo, hi};\n" + "}\n" + : "=r"(res) + : "r"(h2_0), "r"(h2_1)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t half4_to_e5m2x4(uint32_t const h2_0, uint32_t const h2_1) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) + uint32_t res; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "cvt.satfinite.rn.e5m2x2.f16x2 lo, %1;\n" + "cvt.satfinite.rn.e5m2x2.f16x2 hi, %2;\n" + "mov.b32 %0, {lo, hi};\n" + "}\n" + : "=r"(res) + : "r"(h2_0), "r"(h2_1)); + return res; +#else + assert(false); + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Helpers to pack float4 into a destination register with 4 8bit values +template +inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, + float const w) { + return float4_to_char4(x, y, z, w); +}; + +template <> +inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, + float const w) { + return float4_to_e4m3x4(x, y, z, w); +}; + +template <> +inline __device__ uint32_t float4_to_8bitx4(float const x, float const y, float const z, + float const w) { + return float4_to_e5m2x4(x, y, z, w); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1); + +template <> +inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1) { + return half4_to_e4m3x4(h2_0, h2_1); +} + +template <> +inline __device__ uint32_t half4_to_fp8x4(uint32_t const h2_0, uint32_t const h2_1) { + return half4_to_e5m2x4(h2_0, h2_1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t float4_to_fp8x4(float const, float const, float const, float const); + +template <> +inline __device__ 
uint32_t float4_to_fp8x4(float const x, float const y, + float const z, float const w) { + return float4_to_e4m3x4(x, y, z, w); +} + +template <> +inline __device__ uint32_t float4_to_fp8x4(float const x, float const y, + float const z, float const w) { + return float4_to_e5m2x4(x, y, z, w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void fence_view_async_shared() { + // Issue a shared memory fence for async operations (FENCE.VIEW.ASYNC.S) + // only compiles on sm90+ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("fence.proxy.async.shared::cta;\n"); +#else + assert(false); +#endif +} + +inline __device__ void fence_view_async_global() { + // Issue a global memory fence for async operations (FENCE.VIEW.ASYNC.G) + // only compiles on sm90+ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("fence.proxy.async.global::cta;\n"); +#else + assert(false); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ char* align_1024(char* ptr) { + uint64_t address_bit = reinterpret_cast(ptr); + uint64_t offset = address_bit % 1024; + if (offset == 0) { + return ptr; + } else { + return ptr + (1024 - offset); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float atomicMaxFloatPos_(float* addr, float value) { + // VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION. + float old = __int_as_float(atomicMax((int*)addr, __float_as_int(value))); + return old; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ float max3Pos_(float const a, float const b, float const c) { + // VALUE MUST BE POSITIVE! USED ONLY FOR INTERNAL AMAX REDUCTION. + float res; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + int32_t a_ = reinterpret_cast(a); + int32_t b_ = reinterpret_cast(b); + int32_t c_ = reinterpret_cast(c); + int32_t tmp; + asm volatile("max.s16x2 %0, %1, %2;\n" : "=r"(tmp) : "r"(a_), "r"(b_)); + asm volatile("max.s16x2 %0, %0, %1;\n" : "+r"(tmp) : "r"(tmp), "r"(c_)); + res = reinterpret_cast(tmp); +#else + res = fmaxf(a, fmaxf(b, c)); +#endif + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Fast approximate tanh. +static inline __device__ float __tanhf(float x) { +#if (__CUDA_ARCH__ >= 750) + float r = x; + asm("tanh.approx.f32 %0, %0;" : "+f"(r)); + return r; +#else + return tanhf(x); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/warpspec/circular_buffer.h b/csrc/fmha_v2/fmha/warpspec/circular_buffer.h new file mode 100644 index 0000000000..903319490a --- /dev/null +++ b/csrc/fmha_v2/fmha/warpspec/circular_buffer.h @@ -0,0 +1,399 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. 
SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include + +#pragma once + +namespace fmha { +namespace ws { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* Shared storage for barriers needed by both producer and consumer */ +template +struct CircularBufferBarriers { + __align__(8) uint64_t entryProducedBarriers[DEPTH]; + __align__(8) uint64_t entryConsumedBarriers[DEPTH]; + + CircularBufferBarriers() = default; + // CircularBufferBarriers must live in __shared__ -- cannot copy + CircularBufferBarriers(CircularBufferBarriers const& other) = delete; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* Producer class */ +template +class CircularBufferWriter { + protected: + uint32_t _wptr; + uint32_t _phase; + fmha::Arrive_wait _entryConsumedBarriers; + fmha::Arrive_wait _entryProducedBarriers; + + public: + inline __device__ CircularBufferWriter(CircularBufferBarriers* barriers) + : _entryProducedBarriers(barriers->entryProducedBarriers), + _entryConsumedBarriers(barriers->entryConsumedBarriers), + _wptr(0), + _phase(0xffffffff) {} + + inline __device__ int ptr() { return _wptr; } + + // Return the equivalent read phase. + inline __device__ int phase() { return _phase ^ 0xffffffff; } + + /* Reserve space in the buffer for TMA */ + inline __device__ int tmaReserve(int tid0, int transactioncnt) { + int ptr = threadReserve(); + _entryProducedBarriers.bar_arrive_set_transactioncnt(ptr, transactioncnt, tid0); + return ptr; + } + + /* Reserve space in the buffer for producer threads */ + inline __device__ int threadReserve() { + wait(); + return advance(); + } + + inline __device__ int advance() { + int rval = _wptr; + _phase ^= (1 << _wptr); + _wptr += 1; + if (_wptr >= DEPTH) { + _wptr = 0; + } + return rval; + } + + /* Wait for space to become available in the buffer */ + inline __device__ int wait() { + int ready = _entryConsumedBarriers.bar_peek(_wptr, (_phase >> _wptr) & 1); + if (!ready) _entryConsumedBarriers.bar_wait(_wptr, (_phase >> _wptr) & 1); + return _wptr; + } + + /* Signal that data is ready */ + inline __device__ void threadCommit(int tid0, int id) { + if (tid0) { + _entryProducedBarriers.bar_arrive_normal(id); + } + } + + /* Get the barrier address, needed by TMA */ + inline __device__ uint64_t* barrier_ptr(int id) { + return _entryProducedBarriers.get_bar_addr(id); + } + + inline __device__ void setPtr(int ptr) { _wptr = ptr; } + + inline __device__ void setPhase(int phase) { _phase = phase; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* Consumer class */ +template +class CircularBufferReader { + private: + uint32_t _rptr; + uint32_t _phase; + + public: + fmha::Arrive_wait _entryProducedBarriers; + fmha::Arrive_wait _entryConsumedBarriers; + + inline __device__ CircularBufferReader(CircularBufferBarriers* barriers) + : _entryProducedBarriers(barriers->entryProducedBarriers), + _entryConsumedBarriers(barriers->entryConsumedBarriers), + 
_rptr(0), + _phase(0) {} + + inline __device__ void setProducerCta(int cta_id) { + _entryConsumedBarriers.set_bar_base_dsmem(cta_id); + } + + /* Peek at the head */ + inline __device__ int peek() { + return _entryProducedBarriers.bar_peek(_rptr, (_phase >> _rptr) & 1); + } + + /* Wait for the head to be ready */ + inline __device__ int wait() { + _entryProducedBarriers.bar_wait(_rptr, (_phase >> _rptr) & 1); + return _rptr; + } + + /* Advance the head pointer */ + inline __device__ void advance() { + _phase ^= (1 << _rptr); + _rptr += 1; + if (_rptr >= DEPTH) { + _rptr = 0; + } + } + + inline __device__ int ptr() { return _rptr; } + + inline __device__ uint32_t phase() { return _phase; } + + /* Indicate consumption of data at specified pointer. + The producer is now free to overwrite it + */ + inline __device__ void complete(int tid0, int ptr) { + if (tid0) { + if (CGA_SIZE > 1) { + _entryConsumedBarriers.bar_arrive_dsmem(ptr); + } else { + _entryConsumedBarriers.bar_arrive_normal(ptr); + } + } + } + + /* Simplification of complete and advance for cases + where they don't need to be reordered/separated for performance + */ + inline __device__ void pop(int tid0) { + complete(tid0, _rptr); + advance(); + } + + /* Overrides for pointer and phase. Used for shared buffers */ + inline __device__ void setPtr(int ptr) { _rptr = ptr; } + + inline __device__ void setPhase(uint32_t phase) { _phase = phase; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CircularBuffer { + protected: + CircularBufferBarriers _barriers; + + public: + inline __device__ void init(int tid0, int producer_thread_count, int consumer_thread_count) { + if (tid0) { + for (int i = 0; i < DEPTH; i++) { + fmha::bar_create(&_barriers.entryProducedBarriers[i], producer_thread_count); + fmha::bar_create(&_barriers.entryConsumedBarriers[i], consumer_thread_count); + } + } + } + + using Reader = CircularBufferReader; + using Writer = CircularBufferWriter; + + inline __device__ Reader createReader() { return Reader(&_barriers); } + + inline __device__ Writer createWriter() { return Writer(&_barriers); } + + inline __device__ int depth() { return DEPTH; } + + CircularBuffer() = default; + // CircularBuffer must live in __shared__ -- cannot copy + CircularBuffer(CircularBuffer const& other) = delete; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CircularBufferWithDataReader : public CircularBufferReader { + private: + T* _data; + + public: + inline __device__ CircularBufferWithDataReader(CircularBufferBarriers* barriers, T* data) + : CircularBufferReader(barriers), _data(data) {} + + inline __device__ T read() { return _data[this->ptr()]; } + + inline __device__ T pop(int tid0, bool read_data = true) { + T rval; + int ready = this->peek(); + if (!ready) this->wait(); + if (read_data) { + rval = read(); + fmha::fence_view_async_shared(); + } + this->complete(tid0, this->ptr()); + this->advance(); + return rval; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CircularBufferWithDataWriter : public CircularBufferWriter { + private: + T* _data; + + public: + inline __device__ CircularBufferWithDataWriter(CircularBufferBarriers* barriers, T* data) + : CircularBufferWriter(barriers), _data(data) {} + + inline __device__ void write(int ptr, T const& wrdat) { _data[ptr] = wrdat; } + + inline __device__ int 
push(int tid0, T const& wrdat, bool writeData = true, + uint32_t transactioncnt = 0) { + int ptr = this->threadReserve(); + if (tid0 && writeData) { + write(ptr, wrdat); + __threadfence_block(); + } + if (transactioncnt == 0) + this->threadCommit(tid0, ptr); + else + this->_entryProducedBarriers.bar_arrive_set_transactioncnt(ptr, transactioncnt, tid0); + return ptr; + } + + template + inline __device__ int push_with_sync(int tid0, T const& wrdat, bool writeData = true, + uint32_t transactioncnt = 0) { + int ptr = this->threadReserve(); + named_barrier_wait(SYNC_BAR, SYNC_THREADS); + if (tid0 && writeData) { + write(ptr, wrdat); + __threadfence_block(); + } + if (transactioncnt == 0) + this->threadCommit(tid0, ptr); + else + this->_entryProducedBarriers.bar_arrive_set_transactioncnt(ptr, transactioncnt, tid0); + return ptr; + } + + inline __device__ void broadcast(T const& wrdat) { + int offset = this->threadReserve(); + for (int i = 0; i < CGA_SIZE; i++) { + push_to_cta(wrdat, i, offset); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CircularBufferWithData : public CircularBuffer { + private: + T _data[DEPTH]; + + public: + inline __device__ T* data() { return _data; } + + using Reader = CircularBufferWithDataReader; + using Writer = CircularBufferWithDataWriter; + + inline __device__ Reader createReader() { return Reader(&this->_barriers, _data); } + + inline __device__ Writer createWriter() { return Writer(&this->_barriers, _data); } + + CircularBufferWithData() = default; + // Must live in __shared__ -- cannot copy + CircularBufferWithData(CircularBufferWithData const& other) = delete; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct OrderedMutex { + uint64_t barriers[2]; + + inline __device__ void init(int tid0, int threads0, int threads1) { + if (tid0) { + fmha::bar_create(&barriers[0], threads0); + fmha::bar_create(&barriers[1], threads1); + } + } +}; + +class OrderedMutexAccessor { + private: + int _phase; + int _id; + int _barrier_id; + + fmha::Arrive_wait _barriers; + + public: + inline __device__ OrderedMutexAccessor(OrderedMutex& m, int id, int barrier_id) + : _phase(0), _id(id), _barriers(m.barriers), _barrier_id(barrier_id) {} + + inline __device__ void arrive() { _barriers.bar_arrive_normal(_id); } + + inline __device__ void wait() { + int ready = _barriers.bar_peek(_id ^ 1, _phase); + if (!ready) { + _barriers.bar_wait(_id ^ 1, _phase); + } + _phase ^= 1; + } + + inline __device__ void named_bar_arrive() { + // ... + // Softmax ends + // Make sure barrier is not moving around + if (_id == 0) { + named_barrier_wait(_barrier_id, 256); + } + } + + inline __device__ void named_bar_wait() { + // Make sure barrier is not moving around + if (_id == 1) { + named_barrier_wait(_barrier_id, 256); + } + // Softmax starts + // ... 
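+    // The two accessors (ids 0 and 1) form a ping-pong: arrive() signals on this side's barrier
+    // while wait() blocks on the other side's, so one warpgroup's GMMA can overlap the other
+    // warpgroup's softmax math while the two still take turns. For 2-byte elements compute.h
+    // seeds the exchange from warpgroup 1 and then calls wait() before each GEMM and arrive()
+    // right after it; the 1-byte (QGMMA) path uses these named-barrier variants instead.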
+ } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ComputeGroupBarrier { + uint64_t barrier; + + inline __device__ void init(int tid0, int threads) { + if (tid0) { + fmha::bar_create(&barrier, threads); + } + } +}; + +class ComputeGroupBarrierAccessor { + private: + int _phase; + fmha::Arrive_wait _barrier; + + public: + inline __device__ ComputeGroupBarrierAccessor(ComputeGroupBarrier& m) + : _phase(0), _barrier(&m.barrier) {} + + inline __device__ void arrive() { _barrier.bar_arrive_normal(0); } + + inline __device__ void wait() { + int ready = _barrier.bar_peek(0, _phase); + if (!ready) { + _barrier.bar_wait(0, _phase); + } + _phase ^= 1; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace ws +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/warpspec/compute.h b/csrc/fmha_v2/fmha/warpspec/compute.h new file mode 100644 index 0000000000..9aae70b2e7 --- /dev/null +++ b/csrc/fmha_v2/fmha/warpspec/compute.h @@ -0,0 +1,606 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "fmha/alibi_params.h" +#include "fmha/hopper/fragment.h" +#include "fmha/hopper/utils_warpgroup.h" +#include "fmha/softmax.h" +#include "fmha/warpspec/circular_buffer.h" +#include "fmha/warpspec/dma.h" +#include "fmha/warpspec/epilogue.h" + +namespace fmha { +namespace ws { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // Template instruction traits to specialize structs + template class Instruction_traits, + // Kernel Traits + typename Kernel_traits> +struct Compute { + // The shared struct. + using Shared = typename Kernel_traits::Shared; + + // The q, or kv tile reader. + using Circular_buffer_q_reader = typename Kernel_traits::Circular_buffer_q_reader; + using Circular_buffer_kv_reader = typename Kernel_traits::Circular_buffer_kv_reader; + + // The instruction traits for BMM1. + using Traits_p = typename Kernel_traits::Traits_p; + // The instruction traits for BMM2. + using Traits_o = typename Kernel_traits::Traits_o; + + // The CTA description for BMM1. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The CTA description for BMM2. + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + + // The Q shared memory tile. + using Smem_tile_q = typename Kernel_traits::Smem_tile_q; + // The K shared memory tile. + using Smem_tile_k = typename Kernel_traits::Smem_tile_k; + // The V shared memory tile. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The GMMA compute tile for BMM1. + using Compute_tile_p = typename Kernel_traits::Compute_tile_p; + // The GMMA compute tile for BMM2. + using Compute_tile_o = typename Kernel_traits::Compute_tile_o; + + // The MMA tile for the BMM1. + using Mma_tile_p = typename Kernel_traits::Mma_tile_p; + // The MMA tile for the BMM2. 
+ using Mma_tile_o = typename Kernel_traits::Mma_tile_o; + + // The fragment of BMM1 output. + using Fragment_p = typename Compute_tile_o::Fragment; + + // The global memory tile for storing BMM2 output. + using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + + // Softmax + using Softmax = Softmax; + + // BMM2 epilogue + using Tile_o_epilogue = Tile_o_epilogue; + + // The step size of Q loop. + enum { STEP_Q = Kernel_traits::STEP_Q }; + + // The step size of KV loop. + enum { STEP_KV = Kernel_traits::STEP_KV }; + + // The number of compute groups (currently fixed at 2). + enum { NUM_COMPUTE_GROUPS = Kernel_traits::NUM_COMPUTE_GROUPS }; + + // Whether we skip those masked tiles when causal mask is enabled ? + enum { SKIP_CAUSAL_MASK_TILES = Kernel_traits::CAUSAL_MASK && !Kernel_traits::USE_CUSTOM_MASK }; + + // Whether we attend to the specific sliding window or chunk ? + enum { SLIDING_OR_CHUNKED_ATTENTION = Kernel_traits::SLIDING_OR_CHUNKED_ATTENTION }; + + // Are we applying alibi bias (drop FMA optimizations for accuracy reasons). + enum { APPLY_ALIBI = Kernel_traits::APPLY_ALIBI }; + + // Do we use custom mask input ? + enum { USE_CUSTOM_MASK = Kernel_traits::USE_CUSTOM_MASK }; + + // Do we always need to apply the mask ? + enum { ALWAYS_APPLY_MASK = APPLY_ALIBI || USE_CUSTOM_MASK }; + + // Enable mutex for overlapping mma and softmax instructions. + enum { ENABLE_MUTEX = Kernel_traits::ENABLE_MUTEX }; + + // The head_dimension groups. + enum { D_GROUPS = Kernel_traits::D_GROUPS }; + + // The MMA_K groups (corresponding to head_dimension groups). + enum { BMM1_MMAS_K_GROUPS = Kernel_traits::D_GROUPS }; + + // The number of MMAS_K for each head_dimension group. + enum { BMM1_MMAS_K_PER_GROUP = Mma_tile_p::MMAS_K / BMM1_MMAS_K_GROUPS }; + + // The MMA_K groups (corresponding to kv_step groups). + enum { BMM2_MMAS_K_GROUPS = Kernel_traits::BMM2_K_GROUPS }; + + // The number of MMAS_K for each head_dimension group. + enum { BMM2_MMAS_K_PER_GROUP = Mma_tile_o::MMAS_K / BMM2_MMAS_K_GROUPS }; + + // The tile size of V after head_dimension split. + enum { TILE_SIZE_V_PER_D_GROUP = STEP_KV * Kernel_traits::D_PER_GROUP }; + + enum { TILE_SIZE_V = STEP_KV * Kernel_traits::DV }; + + enum { TILE_BYTES_V_PER_D_GROUP = STEP_KV * Kernel_traits::D_BYTES_PER_GROUP }; + + enum { TILE_BYTES_V_PER_K_GROUP = BMM2_MMAS_K_PER_GROUP * Kernel_traits::D_BYTES_PER_GROUP }; + + // Named barrier for inter-warpgroup sync + enum { SYNC_BARRIER = Kernel_traits::MMA_SYNC_BARRIER_ID }; + + // Whether Q and KV is in separate buffer, which means we need to consider different Q and KV + // lengths. + enum { SEPARATE_Q_KV_BUFFER = Kernel_traits::SEPARATE_Q_KV_BUFFER }; + + enum { SAGE_BLOCK_SIZE_Q = Kernel_traits::SAGE_BLOCK_SIZE_Q }; + + // sanitize 0 to -1, avoid DIV BY ZERO below + enum { + SAGE_BLOCK_SIZE_K = Kernel_traits::SAGE_BLOCK_SIZE_K > 0 ? Kernel_traits::SAGE_BLOCK_SIZE_K : -1 + }; + + enum { + SAGE_BLOCK_SIZE_V = Kernel_traits::SAGE_BLOCK_SIZE_V > 0 ? 
Kernel_traits::SAGE_BLOCK_SIZE_V : -1 + }; + + // BLOCK_SIZE_Q should be multiply of STEP_Q (usually 64) so that q scale can be fused into + // scale_bmm1 + static_assert(SAGE_BLOCK_SIZE_Q < 0 || SAGE_BLOCK_SIZE_Q % STEP_Q == 0); + static_assert(SAGE_BLOCK_SIZE_K < 0 || SAGE_BLOCK_SIZE_K % 8 == 0); // 8 = columns of a gmma CORE + static_assert(SAGE_BLOCK_SIZE_V < 0 || + SAGE_BLOCK_SIZE_V % 32 == 0); // 32 = K dimension of a qgmma + + // SAGE_BLOCKS_PER_STEP_X is used to declare scale buffer like `float + // scales_k[SAGE_BLOCKS_PER_STEP_K];` if SAGE_BLOCKS_PER_STEP_X == 0, you will get `zero-sized + // variable is not allowed in device code` error from nvcc, so the minimal value have to be 1. But + // don't worry, unused local variables will be optimized out by compiler. + enum { SAGE_BLOCKS_PER_STEP_K = std::max(STEP_KV / SAGE_BLOCK_SIZE_K, 1) }; + + enum { SAGE_BLOCKS_PER_STEP_V = std::max(STEP_KV / SAGE_BLOCK_SIZE_V, 1) }; + +#define K_TILE_WAIT() \ + int ready_k = cbr_k.peek(); \ + if (!ready_k) { \ + cbr_k.wait(); \ + } + +#define KV_TILE_COMPLETE() \ + cbr_k.complete(tidx == 0, cbr_k.ptr()); \ + cbr_v.complete(tidx == 0, cbr_v.ptr()); \ + cbr_k.advance(); \ + cbr_v.advance(); + +#define COMPUTE_SINGLE_TILE(IS_FIRST_COL, APPLY_MASK) \ + compute_single_tile( \ + params, ctile_p, softmax, ctile_o, p_max, p_sum, tidx, actual_kv_seqlen, alibi_head_scale, \ + USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \ + : (q_step_idx * STEP_Q + head_info.q_tile_offset), \ + kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \ + kv_step_idx == kv_idx_end - 1); + + //////////////////////////////////////////////////////////////////////////////////////////////// + + inline __device__ int div_up(int a, int b) { return (a + b - 1) / b; } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + // Compute the kv_left_mask_end and kv_right_mask_start, where mask is applied when kv_idx < + // kv_left_mask_end or kv_idx >= kv_right_mask_start. + template + inline __device__ std::pair compute_kv_mask_start_end(Params const& params, + int const tile_offset_start, + int const tile_offset_end, + int const kv_idx_end) { + // The kv_left_mask_end is 0 by default. + int kv_left_mask_end = 0; + // The kv_right_mask_start is kv_idx_end - 1 by default, which means only the last kv tile is + // masked. + int kv_right_mask_start = kv_idx_end - 1; + + // Always apply mask is specified. + if constexpr (ALWAYS_APPLY_MASK) { + return std::make_pair(0, 0); + } + + // Is the chunked_attention used ? + bool is_chunked_attention = params.log2_chunked_attention_size > 0; + + // The left mask is needed when we attend to a specific sliding window or chunk. + if constexpr (SLIDING_OR_CHUNKED_ATTENTION) { + // The kv_left_mask_end is the start of the chunk. + kv_left_mask_end = + div_up(is_chunked_attention ? ((tile_offset_end >> params.log2_chunked_attention_size) + << params.log2_chunked_attention_size) + : (tile_offset_end + 1 - params.sliding_window_size), + STEP_KV); + } + + // The right mask is needed when causal mask (including sliding_window_attention or chunked + // attention) is used. + if constexpr (SKIP_CAUSAL_MASK_TILES) { + kv_right_mask_start = tile_offset_start / STEP_KV; + } + + // Return the kv_left_mask_end and kv_right_mask_start. 
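+    // Example (plain causal mask, no sliding window or chunking), assuming STEP_Q = 64 and
+    // STEP_KV = 128: a q tile covering rows [256, 319] gives kv_right_mask_start = 256 / 128 = 2,
+    // so kv tiles 0 and 1 (keys 0..255) run without masking and only the diagonal tile 2 applies
+    // the causal mask, while kv_left_mask_end stays 0 because there is no lower bound.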
+ return std::make_pair(kv_left_mask_end, kv_right_mask_start); + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void run(int warpgroup_id, int tidx, Shared* shared, Params const& params) { + auto head_tracker = shared->head_info_tracker[warpgroup_id].createReader(); + auto cbr = shared->tma_q_tracker[warpgroup_id].createReader(); + + auto cbr_k = shared->tma_k_tracker.createReader(); + auto cbr_v = shared->tma_v_tracker.createReader(); + + // Ctile_p initialize (relies on q_stage, kv_stage). + char* smem_q = reinterpret_cast(&shared->smem_q[warpgroup_id][0]); + char* smem_k = reinterpret_cast(&shared->smem_k[0]); + Compute_tile_p ctile_p(smem_q, smem_k); + + // Softmax + Softmax softmax(params, tidx); + + // Ctile_o initialize (relies on kv_stage). + uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]); + Compute_tile_o ctile_o(0, smem_v); + + // Mutex between two compute groups. + OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER); + // Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions). + if (ENABLE_MUTEX && warpgroup_id == 1 && Kernel_traits::ELEMENT_BYTES == 2) { + mutex_accessor.arrive(); + } + + // While loop for different heads. + while (true) { + typename Shared::Head_info head_info = head_tracker.pop(true); + + if (head_info.kv_steps == -1) { + break; + } + + int const kv_steps = head_info.kv_steps; + int const q_steps = head_info.q_steps; + int const local_q_tile_offset = head_info.local_q_tile_offset; + // The global q tile offset (based on past kv cache). + // Not used by custom mask input. + int const q_tile_offset = + SEPARATE_Q_KV_BUFFER ? head_info.q_tile_offset : head_info.local_q_tile_offset; + int const actual_q_seqlen = head_info.actual_seqlen; + // Contiguous QKV FMHA assumes q, and kv have the same sequence length. + int const actual_kv_seqlen = + SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen; + + // Calculate the alibi head_scaling_factor. + float alibi_head_scale = APPLY_ALIBI ? get_alibi_head_scaling_factor( + head_info.bidh, params.alibi_params) + : 0.f; + // pre-compute the row of the scale for reuse + int sage_scale_row; + if constexpr (Kernel_traits::SAGE_ATTENTION) { + sage_scale_row = head_info.bidb * params.h + head_info.bidh; + } + + // BMM2 epilogue + Tile_o_epilogue tile_o_epilogue(params, head_info); + + int q_step_idx = warpgroup_id; + + // Compute work. + for (; q_step_idx < q_steps; q_step_idx += NUM_COMPUTE_GROUPS) { + // Check whether it is a valid run of q steps. + int const q_offset = q_step_idx * STEP_Q + local_q_tile_offset; + bool const valid_run = q_offset < actual_q_seqlen; + // fuse the scale of q into scale_bmm1 + if constexpr (SAGE_BLOCK_SIZE_Q > 0) { + // I tried another implementation here: store original `scale_bmm1` to a local variable + // to avoid frequent `__ldg`. But experiment shows that the current one is faster. + // A bit counterintuitive. + auto const scale_bmm1 = + params.scale_bmm1_d ? __ldg(params.scale_bmm1_d) : params.scale_bmm1; + int const idx = sage_scale_row * params.sage.q.max_nblock + q_offset / SAGE_BLOCK_SIZE_Q; + *(float*)(&softmax.scale_bmm1_) = + reinterpret_cast(scale_bmm1) * __ldg(¶ms.sage.q.scales[idx]); + } + + // KV tile is shared by two q tiles, + // so we need to consider the last compute group's q tile. 
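+        // (Warpgroup 0 handles q step i and warpgroup 1 handles q step i + 1, but both pop from
+        // the same K/V circular buffers, so the kv range below is taken over the union of the
+        // two q tiles; that keeps both groups consuming and releasing K/V stages in the same
+        // order even though their causal extents differ.)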
+ int const tile_offset_start = q_step_idx * STEP_Q + q_tile_offset; + int const tile_offset_end = tile_offset_start + STEP_Q - 1; + int const warpgroup_tile_offset_start = tile_offset_start - warpgroup_id * STEP_Q; + int const warpgroup_tile_offset_end = + tile_offset_start + (NUM_COMPUTE_GROUPS - warpgroup_id) * STEP_Q - 1; + + // Compute the kv_idx start (inclusive) and end (exclusive). + auto const [kv_idx_start, kv_idx_end] = DMA::Device::compute_kv_tile_idx( + params, warpgroup_tile_offset_start, warpgroup_tile_offset_end, kv_steps); + + // Compute the kv_left_mask_end and kv_right_mask_start, where mask is applied when kv_idx < + // kv_left_mask_end or kv_idx >= kv_right_mask_start. + auto const [kv_left_mask_end, kv_right_mask_start] = + compute_kv_mask_start_end(params, tile_offset_start, tile_offset_end, kv_idx_end); + + // The gmem O tile. + Gmem_tile_o gmem_o(params, head_info, *shared, tidx, + q_step_idx * STEP_Q + local_q_tile_offset); + + // Q ready to use in smem. + int ready = cbr.peek(); + if (!ready) { + cbr.wait(); + } + + static_assert(Mma_tile_p::CORES_M == 2); + float p_max[Mma_tile_p::CORES_M]; + float p_sum[Mma_tile_p::CORES_M]; + + int kv_step_idx = kv_idx_start; + // First K tiles ready to use in smem. + K_TILE_WAIT(); + // Need to apply mask if only kv tile exists. + if (kv_idx_start < kv_left_mask_end || kv_idx_start >= kv_right_mask_start) { + COMPUTE_SINGLE_TILE(true, true); + } else { + COMPUTE_SINGLE_TILE(true, false); + } + KV_TILE_COMPLETE(); + + for (kv_step_idx += 1; kv_step_idx < kv_right_mask_start; ++kv_step_idx) { + // Current step's K tiles ready to use in smem. + K_TILE_WAIT(); + + // Move kv tile to next buffer. + if (D_GROUPS > 1) { + ctile_p.increment_gmma_desc_group(); + } else { + ctile_p.increment_gmma_desc_b_group(); + } + + ctile_o.increment_gmma_desc_group(); + + // Apply the start mask only when sliding window attention is enabled. + if (kv_step_idx < kv_left_mask_end) { + COMPUTE_SINGLE_TILE(false, true); + } else { + COMPUTE_SINGLE_TILE(false, false); + } + + KV_TILE_COMPLETE(); + } + + // Always apply the mask in the end. + for (; kv_step_idx < kv_idx_end; ++kv_step_idx) { + // Current step's K tiles ready to use in smem. + K_TILE_WAIT(); + + // Move kv tile to next buffer. + if (D_GROUPS > 1) { + ctile_p.increment_gmma_desc_group(); + } else { + ctile_p.increment_gmma_desc_b_group(); + } + + ctile_o.increment_gmma_desc_group(); + + COMPUTE_SINGLE_TILE(false, true); + + KV_TILE_COMPLETE(); + } + if (valid_run) { + // Final step's update. + tile_o_epilogue.scale(ctile_o, p_max, p_sum); + // Store o_tile to gmem. + gmem_o.store(ctile_o.acc_); + } + + // Move q, kv to next buffer. 
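+        // (That is, advance the GMMA descriptors to the next stage of the multi-buffered Q/K/V
+        // shared-memory tiles so this warpgroup's next q step reads the buffers the DMA group is
+        // filling for it.)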
+ ctile_p.increment_gmma_desc_a_group(); + ctile_p.increment_gmma_desc_b_group(); + ctile_o.increment_gmma_desc_group(); + + if constexpr (Kernel_traits::RETURN_SOFTMAX_STATS) { + using Mma_tile = typename Traits_p::template Mma_tile; + fmha::Softmax_saver_tma saver(params, head_info); + saver.store(p_sum, p_max, sqrtf(params.d), q_step_idx * STEP_Q, valid_run); + } + } + } + } + + //////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void compute_single_tile( + Params params, Compute_tile_p& ctile_p, Softmax& softmax, Compute_tile_o& ctile_o, + float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M], int const tidx, + int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset, + int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, + Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, bool complete = false) { +// load the scales of K/V from global memory +#define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \ + if constexpr (block_size > 0) { \ + const int _start = col_offset / block_size; \ + const float* _src = \ + params.sage.which.scales + sage_scale_row * params.sage.which.max_nblock + _start; \ + const int _end = params.sage.which.max_nblock - _start; \ + _Pragma("unroll") for (int _i = 0; _i < blocks_per_step; _i++) { \ + dst[_i] = _i < _end ? _src[_i] : 1.0f; \ + } \ + } + +#define LOAD_SCALES_K(scales) LOAD_SCALES_KV(scales, k, SAGE_BLOCKS_PER_STEP_K, SAGE_BLOCK_SIZE_K) + +#define LOAD_SCALES_V(scales) LOAD_SCALES_KV(scales, v, SAGE_BLOCKS_PER_STEP_V, SAGE_BLOCK_SIZE_V) + + // Load the needed packed masks. + softmax.load_packed_mask(row_offset, col_offset); + + // experiments show that here is the best place to load scales of K + float scales_k[SAGE_BLOCKS_PER_STEP_K]; + LOAD_SCALES_K(scales_k) + + // Wait until another warpgroup has already executed HGMMA. + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 2) { + mutex.wait(); + } + + // Ctile_p is only used once by each n step. + ctile_p.clear(); + + // BMM1 (Q x K'). + warpgroup_arrive(); + +// Only single K groups when sizeof(D) <= 128B. +#pragma unroll + for (int kbi = 0; kbi < BMM1_MMAS_K_GROUPS - 1; kbi++) { +#pragma unroll + for (int ki = 0; ki < BMM1_MMAS_K_PER_GROUP; ki++) { + ctile_p.compute(ki, false, ki == BMM1_MMAS_K_PER_GROUP - 1); + } + ctile_p.increment_gmma_desc_group(); + } + +#pragma unroll + for (int ki = 0; ki < BMM1_MMAS_K_PER_GROUP - 1; ki++) { + ctile_p.compute(ki); + } + + ctile_p.compute(BMM1_MMAS_K_PER_GROUP - 1, true, true); + + warpgroup_commit(); + warpgroup_wait<0>(); + + // Arrive when the last tile consumes the q tile. + if (complete) { + cbr.complete(tidx == 0, cbr.ptr()); + cbr.advance(); + } + + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 2) { + // Notify another warpgroup to execute HGMMA. + mutex.arrive(); + } + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1) { + // Wait until another warpgroup has already executed QGMMA. + mutex.named_bar_wait(); + } + + // Fragment p for BMM2 input + Fragment_p frag_p[Mma_tile_o::MMAS_K]; + + // Unpack the elements from bmm1 output to floats. 
+ softmax.unpack(ctile_p); + // apply the scales of K before softmax + if constexpr (SAGE_BLOCK_SIZE_K > 0) { +#pragma unroll + for (int ni = 0; ni < Mma_tile_p::CORES_N; ni++) { + float const scale_k = scales_k[SAGE_BLOCKS_PER_STEP_K * ni / Mma_tile_p::CORES_N]; +#pragma unroll + for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++) { + softmax.elt_[mi][2 * ni] *= scale_k; + softmax.elt_[mi][2 * ni + 1] *= scale_k; + } + } + } + + // Apply the alibi and mask. + softmax.apply_alibi_and_mask(ctile_p, params.alibi_params, alibi_head_scale, + actual_kv_seqlen, row_offset, col_offset); + + // Softmax Exp, max/sum, and update scales. + softmax.compute_and_update_scale(p_max, p_sum); + + // experiments show that here is the best place to load scales of V + float scales_v[SAGE_BLOCKS_PER_STEP_V]; + LOAD_SCALES_V(scales_v) + + // Update flash attention scales and pack it for BMM2 + softmax.pack(ctile_o, frag_p); + + if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1) { + // Notify another warpgroup to execute QGMMA. + mutex.named_bar_arrive(); + } + + // Wait until v buffer is ready. + int ready = cbr_v.peek(); + if (!ready) { + cbr_v.wait(); + } + + warpgroup_arrive(); + + float last_scale_v; + +// Apply the scale of V to partial result. +// Note 2 points: +// 1. Because the matrix V is quantized along the inner dimension, it is necessary to interrupt +// the MMA workflow after processing each BLOCKS_SIZE_V rows of V and scale the intermediate +// results once. For example, STEP_KV=256, qgmma.K=32, then 256/32=8 MMAs are needs, +// so mma_ki = [0,1,2, ..., 7]. If the BLOCK_SIZE_V=64, then after each 2 qgmmas we should scale +// ctile_o. +// 2. The ctile_o is all zero at the beginning. if we directly apply the scale of V after each 2 +// qgmmas, let's see what happens: +// ctile_o = [0] +// ctile_o = (ctile_o + P0 x V0) * s0 = P0 x V0 * s0 +// ctile_o = (ctile_o + P1 x V1) * s1 = P0 x V0 * s0 * s1 + P1 x V1 * s1 +// ctile_o = (ctile_o + P2 x V2) * s2 = P0 x V0 * s0 * s1 * s2 + P1 x V1 * s1 * s2 + P2 x V2 * +// s2 +// ... +// As you see, the actual scale of a V block is the cumulative product of the scales of all +// later blocks. To solve this, we have to preprocess the scale s[i] of block[i] to s[i]/s[i+1], +// and the final block uses the actual scale. +// But to fetch the next scale in next STEP leads to bad performance. So we apply s[i-1]/s[i] to +// current partial result BEFORE each V block. +#define APPLY_SCALE_V(mma_ki) \ + if constexpr (SAGE_BLOCK_SIZE_V > 0) { \ + if (mma_ki % (Mma_tile_o::MMAS_K / SAGE_BLOCKS_PER_STEP_V) == 0) { \ + float _scale_v = scales_v[SAGE_BLOCKS_PER_STEP_V * mma_ki / Mma_tile_o::MMAS_K]; \ + if (mma_ki != 0) { \ + warpgroup_commit(); \ + warpgroup_wait<0>(); \ + } \ + last_scale_v = _scale_v; \ + } \ + } + +// BMM2 (S * V). 
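+// Worked example with illustrative scales (not taken from this kernel): for three V blocks with
+// s = {2, 4, 8}, the factors applied before blocks 1 and 2 are s0/s1 = 0.5 and s1/s2 = 0.5, so the
+// partial result ends up weighting P0*V0, P1*V1 and P2*V2 by 2/8, 4/8 and 8/8, i.e. by s_i
+// relative to the last block's scale; this is also why last_scale_v finishes the loop holding the
+// final block's scale.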
+#pragma unroll + for (int kbi = 0; kbi < BMM2_MMAS_K_GROUPS - 1; kbi++) { +#pragma unroll + for (int ki = 0; ki < BMM2_MMAS_K_PER_GROUP; ++ki) { + int const mma_ki = kbi * BMM2_MMAS_K_PER_GROUP + ki; + APPLY_SCALE_V(mma_ki) + ctile_o.fill_frag_a(frag_p[mma_ki]); + ctile_o.compute(ki, false, ki == BMM2_MMAS_K_PER_GROUP - 1); + } + ctile_o.increment_gmma_desc_group(); + } + +#pragma unroll + for (int ki = 0; ki < BMM2_MMAS_K_PER_GROUP - 1; ++ki) { + int const mma_ki = (BMM2_MMAS_K_GROUPS - 1) * BMM2_MMAS_K_PER_GROUP + ki; + APPLY_SCALE_V(mma_ki) + ctile_o.fill_frag_a(frag_p[mma_ki]); + ctile_o.compute(ki); + } + + APPLY_SCALE_V((Mma_tile_o::MMAS_K - 1)) + ctile_o.fill_frag_a(frag_p[Mma_tile_o::MMAS_K - 1]); + ctile_o.compute(Mma_tile_o::MMAS_K - 1, true, true); + + warpgroup_commit(); + warpgroup_wait<0>(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace ws +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/warpspec/dma.h b/csrc/fmha_v2/fmha/warpspec/dma.h new file mode 100644 index 0000000000..a14ccafdf3 --- /dev/null +++ b/csrc/fmha_v2/fmha/warpspec/dma.h @@ -0,0 +1,874 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include +#include +#include + +#include "fmha/hopper/arrive_wait.h" +#include "fmha/hopper/smem_tile.h" +#include "fmha/utils.h" + +namespace fmha { +namespace ws { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct DMA { + // The shared struct. + using Shared = typename Kernel_traits::Shared; + // The kv buffer writer. + using Circular_buffer_kv_writer = typename Kernel_traits::Circular_buffer_kv_writer; + using Circular_buffer_v_scratch_reader = typename Kernel_traits::Circular_buffer_v_scratch_reader; + using Circular_buffer_v_scratch_writer = typename Kernel_traits::Circular_buffer_v_scratch_writer; + + // The step size of Q loop. + enum { STEP_Q = Kernel_traits::STEP_Q }; + + // The step size of KV loop. + enum { STEP_KV = Kernel_traits::STEP_KV }; + + // The tile size of Q. + enum { TILE_SIZE_Q = STEP_Q * Kernel_traits::D }; + + // The tile size of Q after head_dimension split. + enum { TILE_SIZE_Q_PER_D_GROUP = STEP_Q * Kernel_traits::D_PER_GROUP }; + + // The tile size of K. + enum { TILE_SIZE_K = STEP_KV * Kernel_traits::D }; + + // The tile size of K after head_dimension split. + enum { TILE_SIZE_K_PER_D_GROUP = STEP_KV * Kernel_traits::D_PER_GROUP }; + + // The tile size of V. + enum { TILE_SIZE_V = STEP_KV * Kernel_traits::DV }; + + // The tile size of V after head_dimension split. + enum { TILE_SIZE_V_PER_D_GROUP = TILE_SIZE_K_PER_D_GROUP }; + + // Whether apply causal mask or not. + enum { CAUSAL_MASK = Kernel_traits::CAUSAL_MASK }; + + // Whether use custom mask input or not. + enum { USE_CUSTOM_MASK = Kernel_traits::USE_CUSTOM_MASK }; + + // Whether we skip those masked tiles when causal mask is enabled ? 
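+  // (Skipping is only safe when causality is implicit in the tile coordinates; with a custom
+  // packed mask any kv tile may still contain valid bits, so no tiles are skipped.)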
+ enum { SKIP_CAUSAL_MASK_TILES = CAUSAL_MASK && !USE_CUSTOM_MASK }; + + // Whether we attend to the specific sliding window or chunk ? + enum { SLIDING_OR_CHUNKED_ATTENTION = Kernel_traits::SLIDING_OR_CHUNKED_ATTENTION }; + + // Is heads interleaved ? + enum { HEADS_INTERLEAVED = Kernel_traits::HEADS_INTERLEAVED }; + + // Named barrier for inter-warpgroup sync + enum { SYNC_BARRIER = Kernel_traits::DMA_SYNC_BARRIER_ID }; + + // The number of compute groups (currently fixed at 2). + enum { NUM_COMPUTE_GROUPS = Kernel_traits::NUM_COMPUTE_GROUPS }; + + // The tile scheduling mode: static (0), dynamic (1) + enum { SCHEDULING_MODE = Kernel_traits::SCHEDULING_MODE }; + + // Whether read from paged kv buffers or not. + enum { PAGED_KV_INPUT = Kernel_traits::PAGED_KV_INPUT }; + + // Whether the dma group transposes the v tile explicitly. + enum { DMA_GROUP_TRANSPOSE_V = Kernel_traits::DMA_GROUP_TRANSPOSE_V }; + + // How many threads get involved in the dma group. + enum { NUM_THREADS_IN_DMA_GROUP = Kernel_traits::NUM_THREADS_IN_DMA_GROUP }; + + // Transpose V + // K is the sequence length dimension (128 for GMMA). The unroll factor is decided according to + // empirical evidence so as to avoid register spill. + enum { K_ = STEP_KV % 128 == 0 ? 128 : 64 }; + + static_assert(STEP_KV % K_ == 0); + using Transposer = + Transposer 128 || SLIDING_OR_CHUNKED_ATTENTION) ? 1 : 2 /* UNROLL */>; + + struct Device { + // Only the warpgroup leader initiates mbarriers & TMA operations. + uint32_t elect_one_; + // The sum_s for q. + int sum_s_q_; + // The sum_s for kv. + int sum_s_kv_; + // Tile id for q tile scheduling + uint32_t tile_id_; + + inline __device__ Device(uint32_t elect_one) : elect_one_(elect_one) {} + + //////////////////////////////////////////////////////////////////////////////////////////// + + // Compute the kv tile idx start (inclusive) and end (exclusive). + static inline __device__ std::pair compute_kv_tile_idx( + bert::Fused_multihead_attention_params_v2 const& params, int q_step_offset, int q_step_end, + int kv_steps) { + // The default kv_idx_start and kv_idx_end (exclusive). + int kv_idx_start = 0; + int kv_idx_end = kv_steps; + + // Is the chunked_attention used ? + bool is_chunked_attention = params.log2_chunked_attention_size > 0; + + // Skip initial kv tiles due to sliding_window_size + if (SLIDING_OR_CHUNKED_ATTENTION) { + // The kv_offset_start. + int kv_offset_start = is_chunked_attention + ? ((q_step_offset >> params.log2_chunked_attention_size) + << params.log2_chunked_attention_size) + : max(0, q_step_offset + 1 - params.sliding_window_size); + kv_idx_start = kv_offset_start / STEP_KV; + } + + // Early stop when causal mask is enabled. + if (SKIP_CAUSAL_MASK_TILES) { + kv_idx_end = (q_step_end + STEP_KV - 1) / STEP_KV; + } + + return std::make_pair(kv_idx_start, kv_idx_end); + } + + //////////////////////////////////////////////////////////////////////////////////////////// + + // Packed contiguous QKV input. + inline __device__ void run_packed_qkv(bert::Fused_multihead_attention_params_v2 const& params, + Shared* shared) { + // DMA. 
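+      // The DMA warpgroup acts as the producer: it claims a tile id, then streams Q tiles (one per
+      // compute group per iteration) and the kv tiles they need through the circular buffers,
+      // publishing the head info to both compute warpgroups alongside the first kv load.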
+ int local_wid = (threadIdx.x / 32) % 4; + int tiw = threadIdx.x % 32; + uint32_t smem_tile_id = __cvta_generic_to_shared(&shared->tile_id); + + if (SCHEDULING_MODE == 0) { + tile_id_ = blockIdx.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + + auto cbw0 = shared->tma_q_tracker[0].createWriter(); + auto cbw1 = shared->tma_q_tracker[1].createWriter(); + Circular_buffer_kv_writer cbw_k = shared->tma_k_tracker.createWriter(); + Circular_buffer_kv_writer cbw_v = shared->tma_v_tracker.createWriter(); + Circular_buffer_v_scratch_reader cbr_v_scratch = shared->tma_v_scratch_tracker.createReader(); + Circular_buffer_v_scratch_writer cbw_v_scratch = shared->tma_v_scratch_tracker.createWriter(); + auto headinfo_tracker0 = shared->head_info_tracker[0].createWriter(); + auto headinfo_tracker1 = shared->head_info_tracker[1].createWriter(); + + while (tile_id_ < params.num_tiles) { + // If we do bidh = next_head % h, we'd guarantee b to be spread across CTAs. + + int bidb, tmp, bidh, q_step_offset, q_steps; + + if (SCHEDULING_MODE == 0) { + bidh = tile_id_ % params.h; + bidb = tile_id_ / params.h; + } else { + // Balanced dynamic scheduling + if (CAUSAL_MASK && !SLIDING_OR_CHUNKED_ATTENTION && params.use_balanced_scheduling) { + q_step_offset = (params.num_tiles_per_head - 1 - tile_id_ / (params.b * params.h)) * + NUM_COMPUTE_GROUPS; + tmp = tile_id_ % (params.b * params.h); + bidh = tmp / params.b; + bidb = tmp % params.b; + q_steps = NUM_COMPUTE_GROUPS; + } else { // Unbalanced dynamic scheduling + bidb = tile_id_ / (params.h * params.num_tiles_per_head); + tmp = tile_id_ % (params.h * params.num_tiles_per_head); + bidh = tmp / params.num_tiles_per_head; + q_step_offset = tmp % params.num_tiles_per_head * NUM_COMPUTE_GROUPS; + q_steps = NUM_COMPUTE_GROUPS; + } + } + + cudaTmaDesc const* desc_q = ¶ms.tma_desc_q; + cudaTmaDesc const* desc_k = ¶ms.tma_desc_k; + cudaTmaDesc const* desc_v = ¶ms.tma_desc_v; + int actual_seqlen; + if (params.is_s_padded) { + sum_s_q_ = bidb * params.s; + actual_seqlen = params.cu_q_seqlens[bidb + 1] - params.cu_q_seqlens[bidb]; + } else { + sum_s_q_ = params.cu_q_seqlens[bidb]; + actual_seqlen = params.cu_q_seqlens[bidb + 1] - sum_s_q_; + } + sum_s_kv_ = sum_s_q_; + + // The cumulative packed_mask seqlens. + // Each sequence length in the batch has to be padded to multiple of 128. + int sum_mask_s = params.cu_mask_rows[bidb]; + + if (SCHEDULING_MODE == 0) { + // split work across M + q_steps = (actual_seqlen + STEP_Q - 1) / STEP_Q; + + // Q_steps may be distributed to multiple blocks to increase the occupacy + // when b*h is small. + // The number of q_steps needs to be multiple of 2. + q_steps = (q_steps + gridDim.x - 1) / gridDim.x; + q_steps += (q_steps & 1); + // The last block may process fewer q_steps. + q_step_offset = q_steps * blockIdx.x; + } + + int q_tile_offset = q_step_offset * STEP_Q; + if (q_tile_offset >= actual_seqlen) { + if (SCHEDULING_MODE == 0) { + tile_id_ += gridDim.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + continue; + } + + // Split work across N. 
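+      // Illustrative example: actual_seqlen = 1000 with STEP_KV = 256 gives kv_steps = 4; the
+      // tail of the last kv tile is masked against actual_seqlen by the compute groups.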
+ int const kv_steps = (actual_seqlen + STEP_KV - 1) / STEP_KV; + for (int q_step_idx = 0; q_step_idx < q_steps; q_step_idx += 2) { + load_q(bidh, (q_step_idx + 0 + q_step_offset) * STEP_Q, desc_q, shared->smem_q[0], cbw0); + load_q(bidh, (q_step_idx + 1 + q_step_offset) * STEP_Q, desc_q, shared->smem_q[1], cbw1); + + // Q step bound is 2 tiles away at this moment because of 2x1 math warpgroup + int const q_step_end = (q_step_idx + q_step_offset + 2) * STEP_Q - 1; + + // The kv tile idx range for this q step. + auto const [kv_idx_start, kv_idx_end] = compute_kv_tile_idx( + params, (q_step_idx + q_step_offset) * STEP_Q, q_step_end, kv_steps); + + // Iterate over the kv tiles for this q step. + for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++) { + int bar_id = load_kv(bidh / params.h_q_per_kv, kv_step_idx * STEP_KV, desc_k, desc_v, + shared, cbw_k, cbw_v, cbw_v_scratch); + + // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor + if (q_step_idx == 0 && kv_step_idx == kv_idx_start) { + // Send head info. + typename Shared::Head_info info{ + q_steps, + // q, and kv have the same length. + q_tile_offset, USE_CUSTOM_MASK ? sum_mask_s : q_tile_offset, kv_steps, + // q, and kv have the same length. + actual_seqlen, actual_seqlen, sum_s_q_ * params.h + bidh, bidh, bidb}; + // NOTE(tizheng): The need for the sync after consumer bar wait is to avoid a deadlock + // hazard when DMA thread 0 is ahead of other DMA threads. For example: DMA thread 0 + // have finished consumer bar wait phase 0 and producer bar arrive phase 0, and then + // MMA warps have finished producer bar wait phase 0 and consumer bar arrive phase 1. + // At this time other DMA threads start consumer bar wait phase 0. It will never + // become ready. DMA warps then fail to continue to the next loop. + // + // It is the same consideration for the sync after tmaReserve in load_q and load_kv + // implementation below. + headinfo_tracker0.template push_with_sync( + elect_one_, info); + headinfo_tracker1.template push_with_sync( + elect_one_, info); + } + + if constexpr (DMA_GROUP_TRANSPOSE_V) { + transpose_v_tile(bar_id, shared, cbw_v, cbr_v_scratch); + } + } // kv + } // q + + if (SCHEDULING_MODE == 0) { + tile_id_ += gridDim.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + } // gridDim.y + // Signal compute groups to break. + headinfo_tracker0.template push_with_sync( + elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1}); + headinfo_tracker1.template push_with_sync( + elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1}); + } + + // Support contiguous Q + contiguous/paged KV separate cache. + inline __device__ void run_separate_q_and_kv( + bert::Fused_multihead_attention_params_v2 const& params, Shared* shared) { + // DMA. 
+ int local_wid = (threadIdx.x / 32) % 4; + int tiw = threadIdx.x % 32; + uint32_t smem_tile_id = __cvta_generic_to_shared(&shared->tile_id); + + if (SCHEDULING_MODE == 0) { + tile_id_ = blockIdx.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + + auto cbw0 = shared->tma_q_tracker[0].createWriter(); + auto cbw1 = shared->tma_q_tracker[1].createWriter(); + Circular_buffer_kv_writer cbw_k = shared->tma_k_tracker.createWriter(); + Circular_buffer_kv_writer cbw_v = shared->tma_v_tracker.createWriter(); + Circular_buffer_v_scratch_reader cbr_v_scratch = shared->tma_v_scratch_tracker.createReader(); + Circular_buffer_v_scratch_writer cbw_v_scratch = shared->tma_v_scratch_tracker.createWriter(); + auto headinfo_tracker0 = shared->head_info_tracker[0].createWriter(); + auto headinfo_tracker1 = shared->head_info_tracker[1].createWriter(); + + while (tile_id_ < params.num_tiles) { + // If we do bidh = next_head % h, we'd guarantee b to be spread across CTAs. + + int bidb, tmp, bidh, local_q_tile_offset, q_steps; + + if (SCHEDULING_MODE == 0) { + bidh = tile_id_ % params.h; + bidb = tile_id_ / params.h; + } else if (SCHEDULING_MODE == 1) { + bidb = tile_id_ / (params.h * params.num_tiles_per_head); + tmp = tile_id_ % (params.h * params.num_tiles_per_head); + bidh = tmp / params.num_tiles_per_head; + local_q_tile_offset = (tmp % params.num_tiles_per_head) * NUM_COMPUTE_GROUPS * STEP_Q; + q_steps = NUM_COMPUTE_GROUPS; + } else { // SCHEDULING_MODE == 2 + local_q_tile_offset = (params.num_tiles_per_head - 1 - tile_id_ / (params.b * params.h)) * + NUM_COMPUTE_GROUPS * STEP_Q; + tmp = tile_id_ % (params.b * params.h); + bidh = tmp / params.b; + bidb = tmp % params.b; + q_steps = NUM_COMPUTE_GROUPS; + } + int bidh_kv = bidh / params.h_q_per_kv; + + // Sequence length parameters. + // Take chunked attention (q, and kv may have difference sequence length) into + // consideration. + sum_s_q_ = params.is_s_padded ? bidb * params.s : params.cu_q_seqlens[bidb]; + sum_s_kv_ = params.is_s_padded ? bidb * params.s : params.cu_kv_seqlens[bidb]; + int actual_q_seqlen = params.cu_q_seqlens[bidb + 1] - params.cu_q_seqlens[bidb]; + int actual_kv_seqlen = params.cu_kv_seqlens[bidb + 1] - params.cu_kv_seqlens[bidb]; + int past_kv_length = actual_kv_seqlen - actual_q_seqlen; + + // The cumulative packed_mask seqlens. + // Each sequence length in the batch has to be padded to multiple of 128. + int sum_mask_s = params.cu_mask_rows[bidb]; + + // Prepare the tma descriptors. + cudaTmaDesc const* desc_q = ¶ms.tma_desc_q; + cudaTmaDesc const* desc_k = ¶ms.tma_desc_k; + cudaTmaDesc const* desc_v = ¶ms.tma_desc_v; + + int32_t const* paged_block_offsets = + params.paged_kv_cache.mBlockOffsets + bidb * 2 * params.paged_kv_cache.mMaxBlocksPerSeq; + + if (SCHEDULING_MODE == 0) { + // split work across M + q_steps = (actual_q_seqlen + STEP_Q - 1) / STEP_Q; + + // Q_steps may be distributed to multiple blocks to increase the occupacy + // when b*h is small. + // The number of q_steps needs to be multiple of 2. + q_steps = (q_steps + gridDim.x - 1) / gridDim.x; + q_steps += (q_steps & 1); + local_q_tile_offset = q_steps * blockIdx.x * STEP_Q; + } + + // The last block may process fewer q_steps. + if (local_q_tile_offset >= actual_q_seqlen) { + if (SCHEDULING_MODE == 0) { + tile_id_ += gridDim.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + continue; + } + + // The global q tile offset which includes the past kv cache. 
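+      // Illustrative example: actual_kv_seqlen = 1024 and actual_q_seqlen = 128 give
+      // past_kv_length = 896, so local q row 0 corresponds to global row 896 when the causal /
+      // sliding-window bounds are evaluated against the kv sequence.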
+ int q_tile_offset = local_q_tile_offset + past_kv_length; + // Split work across N. + int const kv_steps = (actual_kv_seqlen + STEP_KV - 1) / STEP_KV; + // Page KV: number of valid kv blocks (others might be nullptr). + int const num_valid_kv_blocks = + (actual_kv_seqlen + params.paged_kv_cache.mTokensPerBlock - 1) >> + params.paged_kv_cache.mTokensPerBlockLog2; + + for (int q_step_idx = 0; q_step_idx < q_steps && actual_kv_seqlen > 0; q_step_idx += 2) { + load_q(bidh, q_step_idx * STEP_Q + local_q_tile_offset, desc_q, shared->smem_q[0], cbw0); + load_q(bidh, (q_step_idx + 1) * STEP_Q + local_q_tile_offset, desc_q, shared->smem_q[1], + cbw1); + + // Q step end is 2 tiles away at this moment because of 2x1 math warpgroup + int const q_step_end = (q_step_idx + 2) * STEP_Q - 1 + q_tile_offset; + + // The kv tile idx range for this q step. + auto const [kv_idx_start, kv_idx_end] = compute_kv_tile_idx( + params, q_step_idx * STEP_Q + q_tile_offset, q_step_end, kv_steps); + + // Iterate over the kv tiles for this q step. + for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++) { + // The barrier id. + int bar_id; + // Load paged kv input. + if constexpr (PAGED_KV_INPUT) { + bar_id = load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks, + params.paged_kv_cache.mTokensPerBlockLog2, + params.blocks_per_tma_load, params.blocks_per_tma_load_log2, + params.paged_kv_cache.mMaxBlocksPerSeq, paged_block_offsets, + desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch); + } else { + bar_id = load_kv(bidh_kv, kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k, cbw_v, + cbw_v_scratch); + } + + // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor + if (q_step_idx == 0 && kv_step_idx == kv_idx_start) { + // Send head info. + typename Shared::Head_info info{q_steps, + local_q_tile_offset, + USE_CUSTOM_MASK ? sum_mask_s : q_tile_offset, + kv_steps, + actual_q_seqlen, + actual_kv_seqlen, + sum_s_q_ * params.h + bidh, + bidh, + bidb}; + headinfo_tracker0.template push_with_sync( + elect_one_, info); + headinfo_tracker1.template push_with_sync( + elect_one_, info); + } + if constexpr (DMA_GROUP_TRANSPOSE_V) { + transpose_v_tile(bar_id, shared, cbw_v, cbr_v_scratch); + } + } // kv + } // q + + if (SCHEDULING_MODE == 0) { + tile_id_ += gridDim.y; + } else { + get_next_tile_id(local_wid, tiw, smem_tile_id, params.tile_id_counter_ptr); + } + } // gridDim.y + + // Signal compute groups to break. + headinfo_tracker0.template push_with_sync( + elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1}); + headinfo_tracker1.template push_with_sync( + elect_one_, {-1, -1, -1, -1, -1, -1, -1, -1}); + } + + // Load q tiles from gmem to smem by TMA. 
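+    // The head dimension is split into D_GROUPS boxes of D_PER_GROUP elements so that each TMA box
+    // stays within the 128B swizzle limit; e.g. (illustrative) D = 128 in fp16 is 256B per row and
+    // is loaded as two 64-element groups.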
+ template + inline __device__ void load_q(int bidh, int q_tile_start_offset, cudaTmaDesc const* desc_q, + Smem_q& smem_q, BufferWriter& cbw) { + int barrier_id = cbw.tmaReserve(elect_one_, TILE_SIZE_Q * Kernel_traits::ELEMENT_BYTES); + + named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); + + // split D into multiple groups in order to satisfy the TMA 128B sizzle mode +#pragma unroll + for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) { + const int32_t coords[3] = {di * Kernel_traits::D_PER_GROUP, bidh, + sum_s_q_ + q_tile_start_offset}; + fmha::utmaldg<3, fmha::cudaTmaDescType::TILED, false>( + desc_q, + __cvta_generic_to_shared( + &smem_q[barrier_id * TILE_SIZE_Q + di * TILE_SIZE_Q_PER_D_GROUP]), + __cvta_generic_to_shared(cbw.barrier_ptr(barrier_id)), coords, elect_one_); + } + } + +#define PREPARE_KV_BUFFER() \ + int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) * Kernel_traits::ELEMENT_BYTES); \ + \ + int v_barrier_id; \ + void* v_barrier_ptr; \ + typename Kernel_traits::Element_data_type* v_smem; \ + \ + if constexpr (DMA_GROUP_TRANSPOSE_V) { \ + v_barrier_id = \ + cbw_v_scratch.tmaReserve(elect_one_, (TILE_SIZE_V) * Kernel_traits::ELEMENT_BYTES); \ + v_barrier_ptr = cbw_v_scratch.barrier_ptr(v_barrier_id); \ + v_smem = shared->smem_v_scratch.data(); \ + } else { \ + v_barrier_id = cbw_v.tmaReserve(elect_one_, (TILE_SIZE_V) * Kernel_traits::ELEMENT_BYTES); \ + v_barrier_ptr = cbw_v.barrier_ptr(v_barrier_id); \ + v_smem = shared->smem_v.data(); \ + } \ + \ + named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); + + // Load k,v tiles from gmem to smem by TMA. + template + inline __device__ int load_kv(int bidh_kv, int kv_tile_start_offset, cudaTmaDesc const* desc_k, + cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, + BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch) { + PREPARE_KV_BUFFER() + + // split D into multiple groups in order to satisfy the TMA 128B sizzle mode +#pragma unroll + for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) { + const int32_t k_coords[3] = {di * Kernel_traits::D_PER_GROUP, bidh_kv, + sum_s_kv_ + kv_tile_start_offset}; + + fmha::utmaldg<3, fmha::cudaTmaDescType::TILED, false>( + desc_k, + __cvta_generic_to_shared( + &shared->smem_k[k_barrier_id * TILE_SIZE_K + di * TILE_SIZE_K_PER_D_GROUP]), + __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_); + } + +#pragma unroll + for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di) { + const int32_t v_coords[3] = {di * Kernel_traits::D_PER_GROUP, bidh_kv, + sum_s_kv_ + kv_tile_start_offset}; + + fmha::utmaldg<3, fmha::cudaTmaDescType::TILED, false>( + desc_v, + __cvta_generic_to_shared( + &v_smem[v_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]), + __cvta_generic_to_shared(v_barrier_ptr), v_coords, elect_one_); + } + + return v_barrier_id; + } + + // Load paged k,v tiles from gmem to smem by TMA. + template + inline __device__ int load_paged_kv(int bidh_kv, int kv_tile_start_offset, + int num_valid_kv_blocks, int tokens_per_block_log2, + int blocks_per_tma_load, int blocks_per_tma_load_log2, + int max_blocks_per_sequence, + int32_t const* paged_block_offsets, + cudaTmaDesc const* desc_k, cudaTmaDesc const* desc_v, + Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, + BufferWriterScratch& cbw_v_scratch) { + PREPARE_KV_BUFFER() + + // Paged KV cache block idx. 
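+      // Illustrative example: with 64 tokens per block (tokens_per_block_log2 = 6) and
+      // kv_tile_start_offset = 384, the tile starts at paged block 6 with offset 0 inside the
+      // block.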
+ int paged_kv_block_idx = kv_tile_start_offset >> tokens_per_block_log2; + int kv_offset_in_block = kv_tile_start_offset & ((1 << tokens_per_block_log2) - 1); + + // coordinates: d, s, h, 1 + int const tile_size_k_per_block = TILE_SIZE_K_PER_D_GROUP >> blocks_per_tma_load_log2; + static_assert(TILE_SIZE_V_PER_D_GROUP == TILE_SIZE_K_PER_D_GROUP, + "KV tile should have the same tensor size."); + for (int bi = 0; bi < blocks_per_tma_load; ++bi) { + int const bounded_block_idx = min(num_valid_kv_blocks - 1, paged_kv_block_idx + bi); + + const int32_t k_paged_block_offset = paged_block_offsets[bounded_block_idx]; + const int32_t v_paged_block_offset = + paged_block_offsets[max_blocks_per_sequence + bounded_block_idx]; + +#pragma unroll + for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) { + const int32_t k_coords[4] = {di * Kernel_traits::D_PER_GROUP, kv_offset_in_block, bidh_kv, + k_paged_block_offset}; + + fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>( + desc_k, + __cvta_generic_to_shared( + &shared->smem_k[k_barrier_id * TILE_SIZE_K + di * TILE_SIZE_K_PER_D_GROUP + + bi * tile_size_k_per_block]), + __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_); + } + +#pragma unroll + for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di) { + const int32_t v_coords[4] = {di * Kernel_traits::D_PER_GROUP, kv_offset_in_block, bidh_kv, + v_paged_block_offset}; + + fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>( + desc_v, + __cvta_generic_to_shared( + &v_smem[v_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP + + bi * tile_size_k_per_block]), + __cvta_generic_to_shared(v_barrier_ptr), v_coords, elect_one_); + } + } + + return v_barrier_id; + } + + template + // Transpose v tile explicitly as QGMMA doesn't support it. + inline __device__ void transpose_v_tile(int v_scratch_barrier_id, Shared* shared, + BufferWriter& cbw_v, + BufferReaderScratch& cbr_v_scratch) { + static_assert(NUM_THREADS_IN_DMA_GROUP == 128, ""); + Transposer transposer(threadIdx.x % NUM_THREADS_IN_DMA_GROUP); + + // Src buffer available + int ready = cbr_v_scratch.peek(); + if (!ready) { + cbr_v_scratch.wait(); + } + uint32_t smem_v_src = __cvta_generic_to_shared(&shared->smem_v_scratch[v_scratch_barrier_id]); + + // Dst buffer available + int v_barrier_id = cbw_v.threadReserve(); + uint32_t smem_v_dst = __cvta_generic_to_shared(&shared->smem_v[v_barrier_id * TILE_SIZE_V]); + +// Explicitly transpose the v buffer in smem for fp8. + +// The transposer currently has support of the following tile sizes: +// - D=32, S (or KV_STEP)=128 +// - D=64, S (or KV_STEP)=64, 128 +// - D=128, S (or KV_STEP)=64, 128 +// In addition, the transposer can only work with contiguous chunk of SMEM. +// +// For example, if V tile size is D=256 S=256, we can divide the TMA load of the V tile +// (SxD) into 2x2 chunks of size 128x128. This way, when tiles (0, 0), (0, 1) are transposed, +// either the load and the store of the data can be performed in a contiguous memory. +// +// Keep in mind in order to match GMMA requirement, we need to store the transposed tiles +// along D dim first then S dim. Leading dimension S after the transpose is at most 128B. 
+// +// Logical: +// D - D I M (contiguous) +// +// 128 128 S +// <------------> <------------> - +// s, d = (0, 0) | s, d = (0, 1) D +// ------------------------------ I +// s, d = (1, 0) | s, d = (1, 1) M +// +// In SMEM: +// D - D I M +// +// 128 128 128 128 S +// <------------> <-------------> <-------------> <------------> - +// s, d = (0, 0) | s, d = (0, 1) | s, d = (1, 0) | s, d = (1, 1) D (contiguous) +// I +// M +// +#pragma unroll + for (int kgroup_idx = 0; kgroup_idx < Kernel_traits::BMM2_K_GROUPS; kgroup_idx++) { +#pragma unroll + for (int dgroup_idx = 0; dgroup_idx < Kernel_traits::DV_GROUPS; dgroup_idx++) { + // Src smem block is k first then d + uint32_t src_offset = + (kgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::D_PER_GROUP + + dgroup_idx * Kernel_traits::D_PER_GROUP * Kernel_traits::STEP_KV) * + Kernel_traits::ELEMENT_BYTES; + + // Dst smem block is d first then k + uint32_t dst_offset = + (dgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::D_PER_GROUP + + kgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::DV) * + Kernel_traits::ELEMENT_BYTES; + + transposer.template transpose_(smem_v_src + src_offset, smem_v_dst + dst_offset); + } + } + + fence_view_async_shared(); // Commit STSM + named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); // Sync before signaling + cbw_v.threadCommit(elect_one_, v_barrier_id); // Signal readiness + cbr_v_scratch.pop(elect_one_); // Advance to next phase + } + + inline __device__ void get_next_tile_id(int local_wid, int tiw, uint32_t smem_tile_id, + uint32_t* tile_id_counter_ptr) { + if constexpr (DMA_GROUP_TRANSPOSE_V) { + if (elect_one_) { + tile_id_ = atomicAdd(tile_id_counter_ptr, 1); + sts(smem_tile_id, tile_id_); + } + fence_view_async_shared(); + named_barrier_wait(SYNC_BARRIER, 128); + if (tiw == 0) { + lds(tile_id_, smem_tile_id); + } + tile_id_ = __shfl_sync(0xffffffff, tile_id_, 0); + // only one warp involved when the dma group doesn't need to transpose the v tile. + } else { + if (elect_one_) { + tile_id_ = atomicAdd(tile_id_counter_ptr, 1); + } + tile_id_ = __shfl_sync(0xffffffff, tile_id_, 0); + } + } + }; + + //////////////////////////////////////////////////////////////////////////////////////////////// + + struct Host { + Host() {} + + // Set TMA descriptors on host, and launch as __grid_constant__. + // Paged KV FMHA parameters. + void init_params(bert::Fused_multihead_attention_params_v2& params, + bert::Fused_multihead_attention_launch_params const& launch_params, + cudaStream_t stream) const { + const uint32_t d = params.d; + const uint32_t dv = params.dv; + const uint32_t h = params.h; + const uint32_t h_kv = params.h_kv; + + // Total sequence length. + const uint32_t total_seqlen = + params.is_s_padded ? (params.b * params.s) : launch_params.total_q_seqlen; + + // O Layout: [total_seqlen, H, DV] + // Per batch tensor size. + uint32_t tensor_size_o[3] = {dv, h, total_seqlen}; + + // Stride size in bytes. Assumes least significant dim is 1 + uint64_t tensor_stride_o[2] = {dv * Kernel_traits::ELEMENT_BYTES, + uint64_t(params.o_stride_in_bytes)}; + + // Starting memory address + char* o_ptr = reinterpret_cast(params.o_ptr); + + // Box size of TMA + uint32_t box_size_o[3] = {Kernel_traits::D_PER_GROUP, 1, 16}; + + // Traversal stride. + uint32_t traversal_stride[3] = {1, 1, 1}; + + // OOB fill zeros. + uint32_t oob_fill = 0; + + // FP32 to TF32 conversion disabled. + uint32_t fp32_to_tf32 = 0; + + // GMMA descriptor mode. 
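+      // Illustrative example: D_PER_GROUP = 64 gives 128 bytes per row in fp16 and selects
+      // SWIZZLE_128B, while the same D_PER_GROUP in fp8 is 64 bytes and selects SWIZZLE_64B.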
+ static constexpr int D_BYTES_PER_GROUP = Kernel_traits::D_BYTES_PER_GROUP; + static constexpr fmha::cudaTmaDescSwizzle swizzle_mode = + (D_BYTES_PER_GROUP > 64 ? fmha::cudaTmaDescSwizzle::SWIZZLE_128B + : D_BYTES_PER_GROUP > 32 ? fmha::cudaTmaDescSwizzle::SWIZZLE_64B + : fmha::cudaTmaDescSwizzle::SWIZZLE_32B); + + static_assert(STEP_KV <= 256 && STEP_Q <= 256, "max box size is 256"); + + // Desc Format (data type). + static constexpr fmha::cudaTmaDescFormat desc_format = (Kernel_traits::ELEMENT_BYTES == 1) + ? fmha::cudaTmaDescFormat::U8 + : fmha::cudaTmaDescFormat::F16_RN; + + fmha::Multiple_tma_descriptor<3> qo_tma_descriptor; + + // TMA O + if (Kernel_traits::USE_TMA_STORE) { + qo_tma_descriptor.set_tma_desctriptor( + o_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_o, tensor_stride_o, + traversal_stride, box_size_o, oob_fill, fp32_to_tf32, ¶ms.tma_desc_o); + } + + auto const layout = launch_params.attention_input_layout; + + // Q always uses 3D tensor + uint32_t tensor_size_q[3] = {d, h, total_seqlen}; + + uint64_t tensor_stride_q[2] = {d * Kernel_traits::ELEMENT_BYTES, + uint64_t(params.q_stride_in_bytes)}; + + char* q_ptr = reinterpret_cast( + layout == fmha::Attention_input_layout::PACKED_QKV ? params.qkv_ptr : params.q_ptr); + + uint32_t box_size_q[3] = {Kernel_traits::D_PER_GROUP, 1, STEP_Q}; + + if (layout == fmha::Attention_input_layout::Q_PAGED_KV) { + // KV in q_paged_kv uses 4D tensor + // Layout: [INT32_MAX, H_KV, TokensPerBlock, D] + const uint32_t tokens_per_block = params.paged_kv_cache.mTokensPerBlock; + uint32_t tensor_size_k[4] = {d, tokens_per_block, h_kv, INT_MAX}; + uint32_t tensor_size_v[4] = {dv, tokens_per_block, h_kv, INT_MAX}; + + uint64_t tensor_stride_k[3]; + tensor_stride_k[0] = params.k_stride_in_bytes / tokens_per_block; // d + tensor_stride_k[1] = params.k_stride_in_bytes; // d * 64 + tensor_stride_k[2] = params.paged_kv_cache.mBytesPerBlock; + uint64_t tensor_stride_v[3]; + // we cannot use dv * Kernel_traits::ELEMENT_BYTES because V may be padded (MLA) + tensor_stride_v[0] = params.v_stride_in_bytes / tokens_per_block; // dv + tensor_stride_v[1] = params.v_stride_in_bytes; // dv * 64 + tensor_stride_v[2] = params.paged_kv_cache.mBytesPerBlock; + + char* kv_ptr = reinterpret_cast(params.paged_kv_cache.mPoolPtr); + + uint32_t box_size_kv[4] = {Kernel_traits::D_PER_GROUP, + std::min(tokens_per_block, STEP_KV), 1, 1}; + + assert(STEP_KV % tokens_per_block == 0 || tokens_per_block % STEP_KV == 0); + params.blocks_per_tma_load = std::max(1, STEP_KV / tokens_per_block); + params.blocks_per_tma_load_log2 = log2(params.blocks_per_tma_load); + + uint32_t traversal_stride[4] = {1, 1, 1, 1}; + + fmha::Multiple_tma_descriptor<4> kv_tma_descriptor; + // K + kv_tma_descriptor.set_tma_desctriptor( + kv_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_k, tensor_stride_k, + traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, ¶ms.tma_desc_k); + // V + kv_tma_descriptor.set_tma_desctriptor( + kv_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_v, tensor_stride_v, + traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, ¶ms.tma_desc_v); + } else { + // Otherwise KV uses 3D tensor + uint32_t tensor_size_k[3] = {d, h_kv, total_seqlen}; + uint32_t tensor_size_v[3] = {dv, h_kv, 
total_seqlen}; + + uint64_t tensor_stride_k[2] = {d * Kernel_traits::ELEMENT_BYTES, + uint64_t(params.k_stride_in_bytes)}; + uint64_t tensor_stride_v[2] = {dv * Kernel_traits::ELEMENT_BYTES, + uint64_t(params.v_stride_in_bytes)}; + + uint32_t box_size_kv[3] = {Kernel_traits::D_PER_GROUP, 1, STEP_KV}; + + char *k_ptr, *v_ptr; + + if (layout == fmha::Attention_input_layout::PACKED_QKV) { + if (!HEADS_INTERLEAVED || h != h_kv) { + // Layout: [total_seqlen, (H, D) + (H_KV, D) + (H_KV, DV)] + // All of MHA in TRTLLM is in this layout, + // and MQA/GQA must use this layout. + k_ptr = q_ptr + h * d * Kernel_traits::ELEMENT_BYTES; + v_ptr = k_ptr + h_kv * d * Kernel_traits::ELEMENT_BYTES; + } else { + // Layout: [total_seqlen, H, D + D + DV] + // Currently only used in MHA in fmha_v2 tests. + tensor_stride_q[0] = tensor_stride_k[0] = tensor_stride_v[0] = + (2 * d + dv) * Kernel_traits::ELEMENT_BYTES; + k_ptr = q_ptr + d * Kernel_traits::ELEMENT_BYTES; + v_ptr = k_ptr + d * Kernel_traits::ELEMENT_BYTES; + } + } else if (layout == fmha::Attention_input_layout::CONTIGUOUS_Q_KV) { + k_ptr = reinterpret_cast(params.kv_ptr); + v_ptr = k_ptr + h_kv * d * Kernel_traits::ELEMENT_BYTES; + } else if (layout == fmha::Attention_input_layout::SEPARATE_Q_K_V) { + k_ptr = reinterpret_cast(params.k_ptr); + v_ptr = reinterpret_cast(params.v_ptr); + } + + fmha::Multiple_tma_descriptor<3> kv_tma_descriptor; + // K + kv_tma_descriptor.set_tma_desctriptor( + k_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_k, tensor_stride_k, + traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, ¶ms.tma_desc_k); + // V + kv_tma_descriptor.set_tma_desctriptor( + v_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_v, tensor_stride_v, + traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, ¶ms.tma_desc_v); + } + // Q + qo_tma_descriptor.set_tma_desctriptor( + q_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_q, tensor_stride_q, + traversal_stride, box_size_q, oob_fill, fp32_to_tf32, ¶ms.tma_desc_q); + } + }; +}; + +} // namespace ws +} // namespace fmha diff --git a/csrc/fmha_v2/fmha/warpspec/epilogue.h b/csrc/fmha_v2/fmha/warpspec/epilogue.h new file mode 100644 index 0000000000..15f8636207 --- /dev/null +++ b/csrc/fmha_v2/fmha/warpspec/epilogue.h @@ -0,0 +1,1091 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include +#include +#include + +namespace fmha { +namespace ws { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Special Softmax struct to handle optimization tricks on Hopper Warp-Specialized Kernels. +template