diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu
index 89fe53b874..8df6066391 100644
--- a/csrc/trtllm_fmha_kernel_launcher.cu
+++ b/csrc/trtllm_fmha_kernel_launcher.cu
@@ -26,6 +26,7 @@
 #include
 #include
+#include "tvm/ffi/error.h"
 #include "tvm_ffi_utils.h"
 
 using tvm::ffi::Optional;
 
@@ -163,6 +164,9 @@ void trtllm_paged_attention_launcher(
       use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
   runner_params.mMultiCtasKvMode = use_multi_block;
 
+  runner_params.cumSeqLensQPtr = cum_seq_lens_q;
+  runner_params.cumSeqLensKvPtr = nullptr;
+
   size_t max_batch_size = 8192;   // todo(Yingyi): get from dlfw
   size_t max_num_qo_heads = 256;  // todo(Yingyi): get from dlfw, in total 8MB
   size_t num_semaphores =
@@ -213,7 +217,8 @@ void trtllm_paged_attention_decode(
     TensorView seq_lens, int64_t max_kv_len, Variant<double, TensorView> bmm1_scale,
     Variant<double, TensorView> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
     int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
-    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
+    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
+    Optional<int64_t> optional_max_q_len, Optional<TensorView> cum_seq_lens_q) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -221,14 +226,41 @@ void trtllm_paged_attention_decode(
     TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i));
   }
   auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
-  // NOTE(Zihao): query is [B, Q, H, D]
-  // where Q is the number of query tokens per request, used in MTP
-  // based on profiled results, always use decode mode for MTP (q_len is small)
-  // example: when kv_len = 10000, q < 200, decode mode is faster
-  int batch_size = query.size(0);
-  int q_len_per_request = query.size(1);
-  int sum_seq_q = batch_size * q_len_per_request;
-  int num_qo_heads = query.size(2);
+  int batch_size;
+  int max_q_len;
+  int sum_seq_q;
+  int num_qo_heads;
+  int* cum_seq_lens_q_ptr = nullptr;
+  if (!optional_max_q_len.has_value()) {
+    // each request has the same length
+
+    // NOTE(Zihao): query is [B, Q, H, D]
+    // where Q is the number of query tokens per request, used in MTP
+    // based on profiled results, always use decode mode for MTP (q_len is small)
+    // example: when kv_len = 10000, q < 200, decode mode is faster
+    TVM_FFI_CHECK(query.ndim() == 4,
+                  "When max_q_len is not provided, query must be of shape [batch_size, q_len, "
+                  "num_qo_heads, head_dim_q]");
+    int q_len_per_request = query.size(1);
+    batch_size = query.size(0);
+    sum_seq_q = batch_size * q_len_per_request;
+    num_qo_heads = query.size(2);
+    max_q_len = q_len_per_request;
+  } else {
+    // each request has a different length
+    TVM_FFI_CHECK(cum_seq_lens_q.has_value(),
+                  "cum_seq_lens_q must be provided when max_q_len is provided");
+    TVM_FFI_CHECK(
+        query.ndim() == 3,
+        "When max_q_len is provided, query must be of shape [sum_seq_q, num_qo_heads, head_dim_q]");
+    // the shape of query: [sum_seq_q, num_qo_heads, head_dim_q]
+    // the shape of cum_seq_lens_q: [batch_size + 1]
+    batch_size = cum_seq_lens_q.value().size(0) - 1;
+    sum_seq_q = query.size(0);
+    num_qo_heads = query.size(1);
+    max_q_len = optional_max_q_len.value();
+    cum_seq_lens_q_ptr = static_cast<int*>(cum_seq_lens_q.value().data_ptr());
+  }
   // Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even.
   int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1);
   int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1);
@@ -284,15 +316,13 @@ void trtllm_paged_attention_decode(
   trtllm_paged_attention_launcher(
       out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(),
       value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
-      static_cast<int*>(seq_lens.data_ptr()),
-      /*cum_seq_lens_q=*/nullptr,
-      /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
-      TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
-      num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
-      kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
-      bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
-      o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count,
-      enable_pdl, workspace_size, stream);
+      static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
+      /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
+      TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool,
+      num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values,
+      kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
+      bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left,
+      sum_seq_q, sparse_mla_top_k, sm_count, enable_pdl, workspace_size, stream);
 }
 
 void trtllm_paged_attention_context(
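Note on the launcher change above: the kernel launcher now accepts two query layouts, selected by whether optional_max_q_len is set. A minimal Python mirror of the C++ shape logic, to make the two modes concrete (illustrative sketch only, not the source; derive_sizes is a hypothetical helper name):

    import torch

    def derive_sizes(query, max_q_len=None, cum_seq_lens_q=None):
        if max_q_len is None:
            # Uniform mode: query is [batch_size, q_len_per_request, num_qo_heads, head_dim_q].
            assert query.ndim == 4
            batch_size, q_len_per_request, num_qo_heads, _ = query.shape
            return batch_size, batch_size * q_len_per_request, num_qo_heads, q_len_per_request
        # Variable-length mode: query is packed as [sum_seq_q, num_qo_heads, head_dim_q] and
        # cum_seq_lens_q is the inclusive prefix sum of per-request q_lens, [batch_size + 1].
        assert cum_seq_lens_q is not None and query.ndim == 3
        batch_size = cum_seq_lens_q.numel() - 1
        sum_seq_q, num_qo_heads, _ = query.shape
        return batch_size, sum_seq_q, num_qo_heads, max_q_len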
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index cc865ae5f8..23e97dea1f 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -1916,6 +1916,8 @@ def _paged_run(
             enable_pdl,
             workspace_size,
             sinks,
+            None,  # max_q_len
+            None,  # cum_seq_lens_q
         )
         return out
 
@@ -2077,12 +2079,14 @@
    q_len_per_req: Optional[int] = 1,
    o_scale: Optional[float] = 1.0,
    mask: Optional[torch.Tensor] = None,
+    max_q_len: Optional[int] = None,
+    cum_seq_lens_q: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, FP4Tensor]:
     """
     Parameters
     ----------
     query : torch.Tensor
-        query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request
+        query tensor with shape [num_tokens, num_heads, head_dim], where num_tokens is the total number of query tokens in the batch.
     kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
         If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim]
         if :attr:`kv_layout` is ``HND``,
@@ -2150,6 +2154,16 @@
     mask : Optional[torch.Tensor] = None
         causal attention mask for xqa speculative decoding.
 
+    max_q_len : Optional[int] = None
+        The maximum query sequence length across all requests when using variable-length queries.
+        Only supported by the trtllm-gen backend. Must be provided together with ``cum_seq_lens_q``.
+        When None, all requests use the uniform query length given by ``q_len_per_req``.
+
+    cum_seq_lens_q : Optional[torch.Tensor] = None
+        Cumulative query sequence lengths for variable-length query support, shape: ``[batch_size + 1]``, dtype: ``torch.int32``.
+        Only supported by the trtllm-gen backend. Must be provided together with ``max_q_len``.
+        When None, all requests use the uniform query length given by ``q_len_per_req``.
+
     Returns
     -------
     out : Union[torch.Tensor, FP4Tensor]
@@ -2181,6 +2195,8 @@
             raise ValueError("xqa backend does not support nvfp4 output")
         if o_sf_scale is not None or o_sf_vec_size is not None:
             raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
+        if max_q_len is not None or cum_seq_lens_q is not None:
+            raise ValueError("xqa backend does not support max_q_len or cum_seq_lens_q")
 
     # Handle out and out_dtype
     if out_dtype is None:
@@ -2305,7 +2321,9 @@
            q_len_per_req,
            query.size(1),
            query.size(2),
-        ),
+        )
+        if q_len_per_req is not None
+        else query,
         k_cache,
         v_cache,
         workspace_buffer,
@@ -2323,6 +2341,8 @@
         enable_pdl,
         workspace_buffer.numel() * workspace_buffer.element_size(),
         sinks,
+        max_q_len,
+        cum_seq_lens_q,
     )
 
     return (
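For reference, a hypothetical end-to-end call of the new variable-length path (a sketch, not a definitive recipe: the populated kv_cache, workspace_buffer, block_tables, seq_lens, max_seq_len, num_qo_heads, and head_dim are assumed to exist from the usual paged-KV setup; only the new keyword arguments are the point here):

    import torch
    import flashinfer

    q_lens = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")  # per-request query lengths
    cum_seq_lens_q = torch.cat(
        [torch.zeros(1, dtype=torch.int32, device="cuda"), torch.cumsum(q_lens, 0).int()]
    )  # [batch_size + 1] = [0, 3, 4, 6]
    query = torch.randn(
        int(q_lens.sum()), num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda"
    )  # packed [sum_seq_q, num_qo_heads, head_dim]

    out = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query, kv_cache, workspace_buffer, block_tables, seq_lens, max_seq_len,
        backend="trtllm-gen",          # xqa rejects max_q_len / cum_seq_lens_q
        q_len_per_req=None,            # skip the uniform-length reshape, keep query 3-D
        max_q_len=int(q_lens.max()),   # must be passed together with cum_seq_lens_q
        cum_seq_lens_q=cum_seq_lens_q,
    )

Passing q_len_per_req=None matters: as the diff above shows, the uniform path reshapes query to [batch, q_len, heads, dim], while None routes the packed 3-D query through unchanged.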
diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index dd0002ff06..89a4d1a973 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -54,8 +54,13 @@ def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
     return q_lens, in_kv_lens, seq_lens
 
 
-def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len):
-    q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
+def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len, max_q_len):
+    if q_len_per_req is not None:
+        assert max_q_len is None, "Cannot specify both q_len_per_req and max_q_len."
+        q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
+    else:
+        assert max_q_len is not None, "Must specify either q_len_per_req or max_q_len."
+        q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32)
     in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int)
     in_kv_lens[-1] = max_in_kv_len
     seq_lens = q_lens + in_kv_lens
@@ -746,6 +751,7 @@ def _test_trtllm_batch_decode(
     max_in_kv_len,
     head_dim,
     device_scale=False,
+    max_q_len=None,
 ):
     """
     Common function for testing trtllm-gen decode.
@@ -767,9 +773,14 @@ def _test_trtllm_batch_decode(
     if backend == "xqa" and q_dtype == "fp8":
         pytest.skip("xqa backend only supports fp16 and bf16 query")
 
-    if o_dtype == "nvfp4" and q_len_per_req > 1:
+    if o_dtype == "nvfp4" and (
+        q_len_per_req is not None
+        and q_len_per_req > 1
+        or max_q_len is not None
+        and max_q_len > 1
+    ):
         # todo(Yingyi): add support for nvfp4 with speculative decoding
-        pytest.skip("nvfp4 is not supported for q_len_per_req > 1")
+        pytest.skip("nvfp4 is not supported for q_len_per_req > 1 or max_q_len > 1 yet")
 
     if backend == "trtllm-gen" and o_dtype == "fp8" and q_dtype != "fp8":
         pytest.skip("trtllm-gen backend only supports fp8 output for fp8 query")
@@ -780,7 +791,7 @@ def _test_trtllm_batch_decode(
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
     q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
-        batch_size, q_len_per_req, max_in_kv_len
+        batch_size, q_len_per_req, max_in_kv_len, max_q_len
     )
 
     # Create query tensor and related data
@@ -835,7 +846,7 @@ def _test_trtllm_batch_decode(
         "window_left": window_left,
     }
     if not enable_sink:
-        if q_len_per_req == 1:
+        if q_len_per_req is not None and q_len_per_req == 1:
             wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
                 workspace_buffer_ref, kv_layout, use_tensor_cores=True
             )
@@ -886,7 +897,8 @@ def _test_trtllm_batch_decode(
         kv_indptr=kv_indptr_tokens,
     )
 
-    if q_len_per_req > 1:
+    if q_len_per_req and q_len_per_req > 1:
+        # only used for xqa speculative decoding
         mask = generate_causal_mask(batch_size, q_len_per_req, GPU_DEVICE)
     else:
         mask = None
@@ -923,6 +935,8 @@ def _test_trtllm_batch_decode(
         q_len_per_req=q_len_per_req,
         o_scale=o_scale,
         mask=mask,
+        max_q_len=max_q_len,
+        cum_seq_lens_q=q_indptr if max_q_len is not None else None,
     )
     if backend == "trtllm-gen":
         # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
@@ -948,7 +962,7 @@ def _test_trtllm_batch_decode(
 
     # convert to float32 for fp8 is not supported by assert_close
     # relax rtol and atol for speculative decoding test
-    if q_len_per_req > 1:
+    if (q_len_per_req and q_len_per_req > 1) or (max_q_len and max_q_len > 1):
         rtol, atol = rtol * 2, atol * 2
 
     # Arbitary small mismatch rate
@@ -967,7 +981,10 @@ def _test_trtllm_batch_decode(
 
     # Only test wrapper with trtllm-gen backend
     if (
-        o_dtype != "nvfp4" and backend == "trtllm-gen"
+        o_dtype != "nvfp4"
+        and backend == "trtllm-gen"
+        # only test the case where all requests have the same q_len
+        and q_len_per_req is not None
     ):  # wrapper api does not support fp4 output yet.
        # test wrapper with trtllm-gen backend
        wrapper_trtllm_gen = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
@@ -1434,3 +1451,84 @@ def test_trtllm_gen_prefill_deepseek_bs1(
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
     )
+
+
+@pytest.mark.parametrize("backend", ["trtllm-gen"])
+@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
+@pytest.mark.parametrize(
+    "batch_size,max_q_len,page_size,num_kv_heads,head_grp_size",
+    [
+        (4, 1, 16, 2, 1),
+        (4, 1, 32, 2, 5),
+        (4, 2, 64, 2, 5),
+        (4, 3, 32, 2, 5),
+        (4, 3, 64, 2, 1),
+        (4, 4, 64, 4, 1),
+        (4, 5, 64, 4, 8),
+        (128, 1, 64, 2, 5),
+        (128, 2, 32, 4, 1),
+        (128, 3, 16, 4, 8),
+        (128, 4, 16, 2, 5),
+        (128, 5, 16, 2, 5),
+        (256, 1, 64, 4, 8),
+        (256, 2, 16, 2, 8),
+        (256, 3, 64, 4, 5),
+        (256, 4, 32, 2, 8),
+        (256, 5, 32, 2, 1),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1, 127])
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+        ("fp16", "fp16", "fp16"),
+        ("bf16", "fp8", "bf16"),
+        ("fp16", "fp8", "fp16"),
+        ("bf16", "fp8", "fp8"),
+        ("fp16", "fp8", "fp8"),
+        ("fp8", "fp8", "bf16"),
+        ("fp8", "fp8", "fp16"),
+        ("fp8", "fp8", "fp8"),
+        ("fp8", "fp8", "nvfp4"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [True, False, None])
+@pytest.mark.parametrize("enable_sink", [True, False])
+@pytest.mark.parametrize("max_in_kv_len", [110])
+@pytest.mark.parametrize("head_dim", [128])
+def test_trtllm_batch_decode_spec(
+    backend,
+    kv_layout,
+    batch_size,
+    max_q_len,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    _test_trtllm_batch_decode(
+        backend,
+        kv_layout,
+        batch_size,
+        None,  # q_len_per_req
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+        max_q_len=max_q_len,
+    )
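Since the variable-length path returns a packed [sum_seq_q, num_heads, head_dim] tensor, per-request outputs can be recovered by indptr-style slicing. A minimal sketch, where out, cum_seq_lens_q, and batch_size follow the naming used above:

    # Recover each request's output rows from the packed result.
    per_request_out = [
        out[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1]]  # [q_lens[i], num_heads, head_dim]
        for i in range(batch_size)
    ]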