[Deepseek v3.2] Support indexer prefill chunking #25999
Conversation
Signed-off-by: Chen Zhang <[email protected]>
        total_seq_lens = 0
if total_seq_lens > 0:
    chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
print("chunk_seq_ids", chunk_seq_ids,
self reminder: remove this print
Code Review
This pull request introduces prefill chunking for the Deepseek v3.2 indexer to manage memory usage by splitting large prefill batches. The changes involve refactoring metadata structures to support chunks and adding logic to split prefill requests based on sequence length.
I've found a critical logic issue in the split_prefill_chunks function that could lead to creating chunks larger than the specified limit, which would cause an assertion failure. I've provided a suggested fix for this. Additionally, there's a debug print statement that should be removed.
Regarding your note about not observing chunking in your test run, this is likely because the max_prefill_buffer_size (currently max_model_len * 2) is larger than the total sequence length of your test case. You could temporarily lower this value to verify that the chunking logic is triggered correctly with the fix.
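For a quick local check (a sketch, not part of the PR: the lengths and buffer value below are made up, and the import path is inferred from the file mentioned in this review), the new helper can also be exercised directly with an artificially small buffer to confirm that chunking triggers:

```python
import torch

# Import path inferred from vllm/v1/attention/backends/mla/indexer.py (assumption).
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks

# Deliberately small buffer so even short prefills get split.
seq_lens_cpu = torch.tensor([800, 600, 700])
chunks = split_prefill_chunks(seq_lens_cpu, max_prefill_buffer_size=1000,
                              reqs_start=0)
print(chunks)  # with the fixed logic: [(0, 1), (1, 2), (2, 3)]
```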
chunk_seq_ids = []
total_seq_lens = 0
for i in range(reqs_start, len(seq_lens_cpu)):
    total_seq_lens += seq_lens_cpu[i]
    if total_seq_lens > max_prefill_buffer_size:
        chunk_seq_ids.append((reqs_start, i))
        reqs_start = i
        total_seq_lens = 0
if total_seq_lens > 0:
    chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
There's a critical logic error in the chunk splitting implementation. The current logic resets total_seq_lens to 0 after creating a chunk, but fails to account for the sequence length of the request that triggered the split. This causes the next chunk to potentially include too many requests, exceeding max_prefill_buffer_size and triggering the assertion in build_one_prefill_chunk.
For example, with seq_lens_cpu = [1000, 1000, 1000] and max_prefill_buffer_size = 1500, the first chunk correctly contains only the first request. However, the second chunk would be created with the remaining two requests, having a total length of 2000, which is over the limit.
Here is a corrected implementation that properly handles chunk creation and sequence length accounting:
chunk_seq_ids = []
total_seq_lens = 0
chunk_start_idx = reqs_start
for i in range(reqs_start, len(seq_lens_cpu)):
    seq_len = seq_lens_cpu[i]
    if total_seq_lens > 0 and total_seq_lens + seq_len > max_prefill_buffer_size:
        chunk_seq_ids.append((chunk_start_idx, i))
        chunk_start_idx = i
        total_seq_lens = seq_len
    else:
        total_seq_lens += seq_len
if total_seq_lens > 0:
    chunk_seq_ids.append((chunk_start_idx, len(seq_lens_cpu)))
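As a sanity check (a sketch only; _split below just wraps the corrected loop above, using .item() to keep plain ints), the reviewer's example now yields one chunk per request:

```python
import torch

def _split(seq_lens_cpu, max_prefill_buffer_size, reqs_start=0):
    # Thin wrapper around the corrected loop above, for a quick local check.
    chunk_seq_ids = []
    total_seq_lens = 0
    chunk_start_idx = reqs_start
    for i in range(reqs_start, len(seq_lens_cpu)):
        seq_len = seq_lens_cpu[i].item()
        if total_seq_lens > 0 and total_seq_lens + seq_len > max_prefill_buffer_size:
            chunk_seq_ids.append((chunk_start_idx, i))
            chunk_start_idx = i
            total_seq_lens = seq_len
        else:
            total_seq_lens += seq_len
    if total_seq_lens > 0:
        chunk_seq_ids.append((chunk_start_idx, len(seq_lens_cpu)))
    return chunk_seq_ids

# Reviewer's example: each 1000-token request lands in its own chunk.
print(_split(torch.tensor([1000, 1000, 1000]), 1500))  # [(0, 1), (1, 2), (2, 3)]
```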
print("chunk_seq_ids", chunk_seq_ids, | ||
seq_lens_cpu[chunk_seq_ids[0][0]:]) |
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Code Review
This pull request introduces prefill chunking for the Deepseek v3.2 indexer to manage memory usage by splitting large prefill requests into smaller chunks. The core logic is implemented in vllm/v1/attention/backends/mla/indexer.py with a new split_prefill_chunks function and modifications to the metadata builder.
My review identifies a critical bug in the split_prefill_chunks function that can lead to incorrect behavior with empty chunks, especially when a single request exceeds the buffer size. I've also pointed out a problematic assertion in build_one_prefill_chunk that would cause crashes in such scenarios. I've provided suggestions to fix these issues to ensure the chunking logic is robust.
def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
                         max_prefill_buffer_size: int,
                         reqs_start: int) -> list[tuple[int, int]]:
    """
    Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
    such that the total sequence length of each chunk is less than the
    maximum prefill buffer size.
    Args:
        seq_lens_cpu: The sequence lengths of the prefill requests.
        max_prefill_buffer_size: The maximum prefill buffer size.
        reqs_start: The start index of the prefill requests.
    Returns:
        A list of tuples of (reqs_start, reqs_end).
    """
    chunk_seq_ids = []
    total_seq_lens = 0
    for i in range(reqs_start, len(seq_lens_cpu)):
        cur_seq_len = seq_lens_cpu[i].item()
        total_seq_lens += cur_seq_len
        if total_seq_lens > max_prefill_buffer_size:
            chunk_seq_ids.append((reqs_start, i))
            reqs_start = i
            total_seq_lens = cur_seq_len
    if total_seq_lens > 0:
        chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
    return chunk_seq_ids
The current implementation of split_prefill_chunks has a bug when handling requests that are larger than max_prefill_buffer_size. It can lead to creating empty chunks and incorrect chunking, which will cause issues downstream. For example, with seq_lens_cpu=torch.tensor([10]), max_prefill_buffer_size=5, and reqs_start=0, the function returns [(0, 0), (0, 1)], where the first chunk is empty.
A corrected implementation should handle this edge case gracefully, typically by placing the large request in its own chunk. It would also be beneficial to add a test case for this scenario (a sketch of such a test follows the suggested implementation below).
Here is a suggested implementation that fixes this issue:
def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
                         max_prefill_buffer_size: int,
                         reqs_start: int) -> list[tuple[int, int]]:
    """
    Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
    such that the total sequence length of each chunk is less than the
    maximum prefill buffer size.
    Args:
        seq_lens_cpu: The sequence lengths of the prefill requests.
        max_prefill_buffer_size: The maximum prefill buffer size.
        reqs_start: The start index of the prefill requests.
    Returns:
        A list of tuples of (reqs_start, reqs_end).
    """
    chunk_seq_ids = []
    total_seq_lens = 0
    for i in range(reqs_start, len(seq_lens_cpu)):
        cur_seq_len = seq_lens_cpu[i].item()
        if total_seq_lens > 0 and total_seq_lens + cur_seq_len > max_prefill_buffer_size:
            chunk_seq_ids.append((reqs_start, i))
            reqs_start = i
            total_seq_lens = 0
        total_seq_lens += cur_seq_len
    if total_seq_lens > 0:
        chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
    return chunk_seq_ids
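A minimal test sketch for the edge case described above (not part of the PR; the test name is mine and it assumes the fixed helper is importable from the indexer module):

```python
import torch

def test_split_prefill_chunks_single_large_request():
    # A single 10-token request with a 5-token buffer should get its own
    # chunk instead of producing an empty (0, 0) chunk.
    chunks = split_prefill_chunks(torch.tensor([10]),
                                  max_prefill_buffer_size=5,
                                  reqs_start=0)
    assert chunks == [(0, 1)]
    assert all(start < end for start, end in chunks)
```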
@heheda12345 I think the bot has a point; we should probably assert that there are no empty chunks.
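For instance, a one-line guard at the end of split_prefill_chunks could look like this (a sketch; the message text is mine):

```python
    # Every chunk must cover at least one request.
    assert all(start < end for start, end in chunk_seq_ids), \
        f"empty prefill chunk in {chunk_seq_ids}"
    return chunk_seq_ids
```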
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
assert total_seq_lens < self.max_prefill_buffer_size
This assertion will fail if a single prefill request has a sequence length greater than or equal to max_prefill_buffer_size. The chunking logic should be able to handle such cases by putting the large request into its own chunk, but this assertion prevents that and will cause a crash.
As you noted in the PR description ("I don't have a trace that total_seq_lens < self.max_prefill_buffer_size"), this is indeed a problematic check. It should be removed to allow processing of single large requests.
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
chunk_seq_ids = []
total_seq_lens = 0
for i in range(reqs_start, len(seq_lens_cpu)):
    cur_seq_len = seq_lens_cpu[i].item()
    total_seq_lens += cur_seq_len
    if total_seq_lens > max_prefill_buffer_size:
        chunk_seq_ids.append((reqs_start, i))
Prevent prefill chunks from equaling buffer size
The new split_prefill_chunks only starts a new chunk when the accumulated length is strictly greater than max_prefill_buffer_size. However, build_one_prefill_chunk asserts total_seq_lens < self.max_prefill_buffer_size, so a chunk whose total length is exactly the buffer size slips through the splitter and immediately triggers the assertion. Two prefill requests with length max_model_len each will produce a chunk of exactly 2 * max_model_len (the configured buffer), causing an assertion failure and request crash. The split predicate should use >= or the assertion should accept equality so that valid workloads do not abort.
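The boundary case is easy to check directly (a sketch; max_model_len here is an arbitrary illustrative value and split_prefill_chunks is assumed importable):

```python
import torch

# Two prefills that exactly fill the configured buffer (max_model_len * 2).
max_model_len = 1024
max_prefill_buffer_size = 2 * max_model_len
seq_lens_cpu = torch.tensor([max_model_len, max_model_len])

chunks = split_prefill_chunks(seq_lens_cpu, max_prefill_buffer_size, reqs_start=0)
print(chunks)  # [(0, 2)] -- a single chunk whose length equals the buffer

total = seq_lens_cpu.sum().item()
assert total <= max_prefill_buffer_size   # relaxed check: passes
# assert total < max_prefill_buffer_size  # original strict check: would fail
```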
Signed-off-by: Chen Zhang <[email protected]>
@codex review
To use Codex here, create a Codex account and connect it to GitHub.
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
def build_one_prefill_chunk(self, reqs_start, reqs_end,
                            query_start_loc_cpu, seq_lens_cpu,
                            block_table):
    prefill_query_start_loc = query_start_loc_cpu[
        reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
    cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
        prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
        self.device)
    token_start = query_start_loc_cpu[reqs_start].item()
    token_end = query_start_loc_cpu[reqs_end].item()
    total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
    assert total_seq_lens <= self.max_prefill_buffer_size
    cu_seq_lens = torch.cat([
        torch.zeros(1, dtype=torch.int32),
        seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
    ]).to(torch.int32).to(self.device)
    return DeepseekV32IndexerPrefillChunkMetadata(
        cu_seqlen_ks=cu_seqlen_ks,
        cu_seqlen_ke=cu_seqlen_ke,
        cu_seq_lens=cu_seq_lens,
        total_seq_lens=total_seq_lens,
        block_table=block_table[reqs_start:reqs_end],
Store prefill chunk length as int, not tensor
build_one_prefill_chunk keeps total_seq_lens as the tensor returned by seq_lens_cpu[reqs_start:reqs_end].sum() and places it in DeepseekV32IndexerPrefillChunkMetadata. The downstream indexer (sparse_attn_indexer in deepseek_v2.py) uses this field to size temporary buffers (torch.empty([chunk.total_seq_lens, …])). Because the metadata now contains a 0-D tensor rather than a Python int, those allocations will raise TypeError: empty(): argument 'size' must be tuple of ints whenever a prefill chunk is processed. Convert total_seq_lens to an int (e.g. .item()) before storing it.
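A sketch of that conversion inside build_one_prefill_chunk (mirroring the lines quoted above):

```python
    # .item() turns the 0-D tensor into a plain Python int, so downstream
    # torch.empty([chunk.total_seq_lens, ...]) receives an int size.
    total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum().item()
    assert total_seq_lens <= self.max_prefill_buffer_size
```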
Makes sense to me! I think this is a pretty good solution, nice work!
I think the bot was right about that one comment; an assert could be helpful (even though I know it currently would never be hit, it would be nice to have in case someone mucks around with get_max_prefill_buffer_size).
Signed-off-by: Chen Zhang <[email protected]>
Purpose
Split the prefill into multiple steps, with each step containing a subset of the prefill requests. With this approach, we can avoid the large intermediate output caused by gathering the KV cache.
Test Plan
20-shot GSM8K
Test Result
It does not crash, and I'm sure it did some chunking.