66 changes: 48 additions & 18 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -26,6 +26,7 @@
#include <sstream>
#include <unordered_map>

#include "tvm/ffi/error.h"
#include "tvm_ffi_utils.h"

using tvm::ffi::Optional;
@@ -163,6 +164,9 @@ void trtllm_paged_attention_launcher(
use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
runner_params.mMultiCtasKvMode = use_multi_block;

runner_params.cumSeqLensQPtr = cum_seq_lens_q;
runner_params.cumSeqLensKvPtr = nullptr;

size_t max_batch_size = 8192; // todo(Yingyi): get from dlfw
size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB
size_t num_semaphores =
@@ -213,22 +217,50 @@ void trtllm_paged_attention_decode(
TensorView seq_lens, int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<int64_t> optional_max_q_len, Optional<TensorView> cum_seq_lens_q) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
for (int i = 0; i < key_cache.ndim(); i++) {
TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i));
}
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
// NOTE(Zihao): query is [B, Q, H, D]
// where Q is the number of query tokens per request, used in MTP
// based on profiled results, always use decode mode for MTP (q_len is small)
// example: when kv_len = 10000, q < 200, decode mode is faster
int batch_size = query.size(0);
int q_len_per_request = query.size(1);
int sum_seq_q = batch_size * q_len_per_request;
int num_qo_heads = query.size(2);
int batch_size;
int max_q_len;
int sum_seq_q;
int num_qo_heads;
int* cum_seq_lens_q_ptr = nullptr;
if (!optional_max_q_len.has_value()) {
// each request has the same length

// NOTE(Zihao): query is [B, Q, H, D]
// where Q is the number of query tokens per request, used in MTP
// based on profiled results, always use decode mode for MTP (q_len is small)
// example: when kv_len = 10000, q < 200, decode mode is faster
TVM_FFI_CHECK(query.ndim() == 4,
"When max_q_len is not provided, query must be of shape [batch_size, q_len, "
"num_qo_heads, head_dim_q]");
int q_len_per_request = query.size(1);
batch_size = query.size(0);
sum_seq_q = batch_size * q_len_per_request;
num_qo_heads = query.size(2);
max_q_len = q_len_per_request;
} else {
// each request has a different length
TVM_FFI_CHECK(cum_seq_lens_q.has_value(),
"cum_seq_lens_q must be provided when max_q_len is provided");
TVM_FFI_CHECK(
query.ndim() == 3,
"When max_q_len is provided, query must be of shape [sum_seq_q, num_qo_heads, head_dim_q]");
// the shape of query: [sum_seq_q, num_qo_heads, head_dim_q]
// the shape of cum_seq_lens_q: [batch_size + 1]
batch_size = cum_seq_lens_q.value().size(0) - 1;
sum_seq_q = query.size(0);
num_qo_heads = query.size(1);
max_q_len = optional_max_q_len.value();
cum_seq_lens_q_ptr = static_cast<int*>(cum_seq_lens_q.value().data_ptr());
}
// Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even.
int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1);
int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1);
@@ -284,15 +316,13 @@ void trtllm_paged_attention_decode(
trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/nullptr,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count,
enable_pdl, workspace_size, stream);
static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool,
num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values,
kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left,
sum_seq_q, sparse_mla_top_k, sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_paged_attention_context(
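Note: the sketch below is not part of the diff; it is a minimal torch-only restatement (with illustrative names) of the shape handling the launcher now performs, i.e. deriving `batch_size`, `sum_seq_q`, `num_qo_heads`, and `max_q_len` from either a 4-D uniform-length query or a 3-D ragged query plus `cum_seq_lens_q`.

```python
from typing import Optional, Tuple

import torch


def derive_decode_shapes(
    query: torch.Tensor,
    max_q_len: Optional[int] = None,
    cum_seq_lens_q: Optional[torch.Tensor] = None,
) -> Tuple[int, int, int, int]:
    """Mirror of the uniform/ragged branch in trtllm_paged_attention_decode (illustrative only)."""
    if max_q_len is None:
        # Uniform path: query is [batch_size, q_len_per_request, num_qo_heads, head_dim_q].
        assert query.dim() == 4
        batch_size, q_len_per_request, num_qo_heads, _ = query.shape
        return batch_size, batch_size * q_len_per_request, num_qo_heads, q_len_per_request
    # Ragged path: query is [sum_seq_q, num_qo_heads, head_dim_q] and
    # cum_seq_lens_q is the [batch_size + 1] int32 prefix sum of per-request query lengths.
    assert cum_seq_lens_q is not None and query.dim() == 3
    batch_size = cum_seq_lens_q.numel() - 1
    sum_seq_q, num_qo_heads, _ = query.shape
    return batch_size, sum_seq_q, num_qo_heads, max_q_len
```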
24 changes: 22 additions & 2 deletions flashinfer/decode.py
@@ -1916,6 +1916,8 @@ def _paged_run(
enable_pdl,
workspace_size,
sinks,
None, # max_q_len
None, # cum_seq_lens_q
)
return out

@@ -2077,12 +2079,14 @@ def trtllm_batch_decode_with_kv_cache(
q_len_per_req: Optional[int] = 1,
o_scale: Optional[float] = 1.0,
mask: Optional[torch.Tensor] = None,
max_q_len: Optional[int] = None,
cum_seq_lens_q: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
----------
query : torch.Tensor
query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request
query tensor with shape [num_tokens, num_heads, head_dim], where num_tokens is the total number of query tokens in the batch.

kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``,
@@ -2150,6 +2154,16 @@ def trtllm_batch_decode_with_kv_cache(
mask : Optional[torch.Tensor] = None
causal attention mask for xqa speculative decoding.

max_q_len : Optional[int] = None
The maximum query sequence length across all requests when using variable-length queries.
Only supported by the trtllm-gen backend. Must be provided together with ``cum_seq_lens_q``.
When None, all requests use the uniform query length specified by ``q_len_per_req``.

cum_seq_lens_q : Optional[torch.Tensor] = None
Cumulative query sequence lengths for variable-length query support, shape: ``[batch_size + 1]``, dtype: ``torch.int32``.
Only supported by the trtllm-gen backend. Must be provided together with ``max_q_len``.
When None, all requests use the uniform query length specified by ``q_len_per_req``.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
@@ -2181,6 +2195,8 @@
raise ValueError("xqa backend does not support nvfp4 output")
if o_sf_scale is not None or o_sf_vec_size is not None:
raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
if max_q_len is not None or cum_seq_lens_q is not None:
raise ValueError("xqa backend does not support cum_seq_lens_q")

# Handle out and out_dtype
if out_dtype is None:
@@ -2305,7 +2321,9 @@ def trtllm_batch_decode_with_kv_cache(
q_len_per_req,
query.size(1),
query.size(2),
),
)
if q_len_per_req is not None
else query,
k_cache,
v_cache,
workspace_buffer,
@@ -2323,6 +2341,8 @@
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
max_q_len,
cum_seq_lens_q,
)

return (
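Usage sketch for the new keyword arguments (not part of the diff; tensor names and sizes are illustrative). The variable-length path expects the query packed token-contiguously and an int32 ``[batch_size + 1]`` prefix sum of per-request query lengths:

```python
import torch

# Hypothetical per-request queries of different lengths (e.g. speculative decoding / MTP).
num_qo_heads, head_dim = 8, 128
per_request_q = [
    torch.randn(n, num_qo_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    for n in (2, 1, 3)
]

# Pack into the ragged layout documented above: [num_tokens, num_heads, head_dim].
query = torch.cat(per_request_q, dim=0)

# cum_seq_lens_q: [batch_size + 1], int32, prefix sum of per-request query lengths.
q_lens = torch.tensor([q.shape[0] for q in per_request_q], device="cuda", dtype=torch.int32)
cum_seq_lens_q = torch.zeros(q_lens.numel() + 1, device="cuda", dtype=torch.int32)
cum_seq_lens_q[1:] = torch.cumsum(q_lens, dim=0)
max_q_len = int(q_lens.max())

assert cum_seq_lens_q[-1].item() == query.shape[0]  # total packed query tokens

# query, max_q_len, and cum_seq_lens_q are then passed to
# trtllm_batch_decode_with_kv_cache (trtllm-gen backend only); the paged-KV-cache
# arguments (kv cache, page table, per-request KV lengths, scales, ...) are unchanged.
```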
116 changes: 107 additions & 9 deletions tests/attention/test_trtllm_gen_attention.py
@@ -54,8 +54,13 @@ def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
return q_lens, in_kv_lens, seq_lens


def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len):
q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len, max_q_len):
if q_len_per_req is not None:
assert max_q_len is None, "Cannot specify both q_len_per_req and max_q_len."
q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
else:
assert max_q_len is not None, "Must specify either q_len_per_req or max_q_len."
q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32)
in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int)
in_kv_lens[-1] = max_in_kv_len
seq_lens = q_lens + in_kv_lens
@@ -746,6 +751,7 @@ def _test_trtllm_batch_decode(
max_in_kv_len,
head_dim,
device_scale=False,
max_q_len=None,
):
"""
Common function for testing trtllm-gen decode.
@@ -767,9 +773,14 @@
if backend == "xqa" and q_dtype == "fp8":
pytest.skip("xqa backend only supports fp16 and bf16 query")

if o_dtype == "nvfp4" and q_len_per_req > 1:
if o_dtype == "nvfp4" and (
q_len_per_req is not None
and q_len_per_req > 1
or max_q_len is not None
and max_q_len > 1
):
# todo(Yingyi): add support for nvfp4 with speculative decoding
pytest.skip("nvfp4 is not supported for q_len_per_req > 1")
pytest.skip("nvfp4 is not supported for q_len_per_req > 1 or max_q_len > 1 yet")

if backend == "trtllm-gen" and o_dtype == "fp8" and q_dtype != "fp8":
pytest.skip("trtllm-gen backend only supports fp8 output for fp8 query")
@@ -780,7 +791,7 @@
# Generate random sequence lengths
num_qo_heads = num_kv_heads * head_grp_size
q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
batch_size, q_len_per_req, max_in_kv_len
batch_size, q_len_per_req, max_in_kv_len, max_q_len
)

# Create query tensor and related data
@@ -835,7 +846,7 @@
"window_left": window_left,
}
if not enable_sink:
if q_len_per_req == 1:
if q_len_per_req is not None and q_len_per_req == 1:
wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer_ref, kv_layout, use_tensor_cores=True
)
@@ -886,7 +897,8 @@
kv_indptr=kv_indptr_tokens,
)

if q_len_per_req > 1:
if q_len_per_req and q_len_per_req > 1:
# only used for xqa speculative decoding
mask = generate_causal_mask(batch_size, q_len_per_req, GPU_DEVICE)
else:
mask = None
@@ -923,6 +935,8 @@
q_len_per_req=q_len_per_req,
o_scale=o_scale,
mask=mask,
max_q_len=max_q_len,
cum_seq_lens_q=q_indptr if max_q_len is not None else None,
)
if backend == "trtllm-gen":
# check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
@@ -948,7 +962,7 @@

# convert to float32 for fp8 is not supported by assert_close
# relax rtol and atol for speculative decoding test
if q_len_per_req > 1:
if (q_len_per_req and q_len_per_req > 1) or (max_q_len and max_q_len > 1):
rtol, atol = rtol * 2, atol * 2

# Arbitary small mismatch rate
@@ -967,7 +981,10 @@

# Only test wrapper with trtllm-gen backend
if (
o_dtype != "nvfp4" and backend == "trtllm-gen"
o_dtype != "nvfp4"
and backend == "trtllm-gen"
# only test the case where all requests share the same q_len
and q_len_per_req is not None
): # wrapper api does not support fp4 output yet.
# test wrapper with trtllm-gen backend
wrapper_trtllm_gen = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
@@ -1434,3 +1451,84 @@ def test_trtllm_gen_prefill_deepseek_bs1(
test_trtllm_gen_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
)


@pytest.mark.parametrize("backend", ["trtllm-gen"])
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
@pytest.mark.parametrize(
"batch_size,max_q_len,page_size,num_kv_heads,head_grp_size",
[
(4, 1, 16, 2, 1),
(4, 1, 32, 2, 5),
(4, 2, 64, 2, 5),
(4, 3, 32, 2, 5),
(4, 3, 64, 2, 1),
(4, 4, 64, 4, 1),
(4, 5, 64, 4, 8),
(128, 1, 64, 2, 5),
(128, 2, 32, 4, 1),
(128, 3, 16, 4, 8),
(128, 4, 16, 2, 5),
(128, 5, 16, 2, 5),
(256, 1, 64, 4, 8),
(256, 2, 16, 2, 8),
(256, 3, 64, 4, 5),
(256, 4, 32, 2, 8),
(256, 5, 32, 2, 1),
],
)
@pytest.mark.parametrize("window_left", [-1, 127])
@pytest.mark.parametrize(
"q_dtype,kv_dtype,o_dtype",
[
("bf16", "bf16", "bf16"),
("fp16", "fp16", "fp16"),
("bf16", "fp8", "bf16"),
("fp16", "fp8", "fp16"),
("bf16", "fp8", "fp8"),
("fp16", "fp8", "fp8"),
("fp8", "fp8", "bf16"),
("fp8", "fp8", "fp16"),
("fp8", "fp8", "fp8"),
("fp8", "fp8", "nvfp4"),
],
)
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("enable_sink", [True, False])
@pytest.mark.parametrize("max_in_kv_len", [110])
@pytest.mark.parametrize("head_dim", [128])
def test_trtllm_batch_decode_spec(
backend,
kv_layout,
batch_size,
max_q_len,
page_size,
num_kv_heads,
head_grp_size,
window_left,
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
):
_test_trtllm_batch_decode(
backend,
kv_layout,
batch_size,
None, # q_len_per_req
page_size,
num_kv_heads,
head_grp_size,
window_left,
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
max_q_len=max_q_len,
)