[WIP] Add Chunked Shared Attention with Dense Biases #784
base: main
Conversation
Signed-off-by: Konrad Zawora <[email protected]>
Signed-off-by: Konrad Zawora <[email protected]>
🚧 CI Blocked: The main CI workflow was not started for the following reason:
Pull request overview
This PR introduces chunked processing for shared blocks in unified attention to address memory issues when handling large numbers of shared blocks. The main motivation is that the full shared bias tensor [query_len, num_shared_blocks, block_size] can become prohibitively large (e.g., 19.53 GiB for 8k query length and 10k shared blocks), leading to out-of-memory errors.
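As a quick sanity check on those figures, the sizes follow directly from the tensor shapes. The snippet below assumes 2-byte elements (e.g. bf16), block_size 128, and the 64-blocks-per-chunk default quoted in the PR description:

```python
# Quick size check for the figures above, assuming 2-byte elements (e.g. bf16).
query_len, num_shared_blocks, block_size, chunk_size = 8192, 10_000, 128, 64

full_bias = query_len * num_shared_blocks * block_size * 2   # full shared bias tensor
chunk_bias = query_len * chunk_size * block_size * 2         # one chunk's bias tensor

print(f"full bias:  {full_bias / 2**30:.2f} GiB")   # ~19.53 GiB
print(f"chunk bias: {chunk_bias / 2**20:.2f} MiB")  # 128.00 MiB
```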
Key changes:
- Chunked attention processing with online softmax merging to reduce memory footprint from ~19.53 GiB to ~128 MiB for large scenarios
- Dense bias generation approach that performs scatter operations on CPU and broadcasts on HPU, avoiding dynamic-length coordinate arrays
- New `SharedBlockChunkedBiasData` dataclass and `_partial_attn_shared_chunked()` function for chunk-wise processing
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| vllm_gaudi/v1/worker/hpu_model_runner.py | Updates unified config to check for both regular and chunked shared bias; removes extraneous blank line |
| vllm_gaudi/extension/unified_batch.py | Adds dense bias generator, implements chunked processing logic with configurable chunk size, and adds defensive checks for optional functions |
| vllm_gaudi/extension/unified.py | Implements core chunked attention logic with new helper functions, refactors shared attention to support both full and chunked modes, and optimizes tensor caching |
```python
# With chunked dense generation, we only allocate (target_qlen, target_shared_blocks) for block_usages
# instead of the full (target_qlen, target_shared_blocks, block_size) bias tensor.
# Bias is generated per chunk: (target_qlen, chunk_size, block_size)
default_chunk_size = 32  # Process up to 64 blocks at a time for shared attention
```
Copilot AI · Jan 7, 2026
Comment states '64 blocks' but the variable is set to 32. The comment should match the actual value.
```diff
- default_chunk_size = 32  # Process up to 64 blocks at a time for shared attention
+ default_chunk_size = 32  # Process up to 32 blocks at a time for shared attention
```
```python
continue
num_spec_tokens = len([i for i in spec_tokens if i != -1])
num_scheduled_tokens[idx] = num_spec_tokens + 1
if scheduled_spec_decode_tokens is not None:
```
Copilot AI · Jan 7, 2026
The conditional check for scheduled_spec_decode_tokens has been added, but the indentation of the for-loop starting at line 580 was not updated. This changes the logic: the loop now only executes when scheduled_spec_decode_tokens is not None, whereas previously it always executed. Ensure the indentation is correct based on the intended behavior.
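A minimal, hypothetical sketch of the concern (placeholder names, not the actual model-runner code): whether the loop is indented under the new guard decides if it still runs when scheduled_spec_decode_tokens is None.

```python
# Hypothetical sketch of the indentation concern, not the real hpu_model_runner code.
def count_scheduled_tokens(num_reqs, scheduled_spec_decode_tokens):
    num_scheduled_tokens = [1] * num_reqs
    if scheduled_spec_decode_tokens is not None:
        # Loop indented under the guard: skipped entirely when there are no
        # spec-decode tokens (the new behavior the reviewer is asking about).
        for idx, spec_tokens in enumerate(scheduled_spec_decode_tokens):
            num_spec_tokens = len([i for i in spec_tokens if i != -1])
            num_scheduled_tokens[idx] = num_spec_tokens + 1
    # If the loop instead stayed at this outer level (the previous behavior),
    # it would always execute and would fail when the argument is None.
    return num_scheduled_tokens
```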
```python
if get_cumsum_and_arange is not None:
    cu_num_tokens, _ = get_cumsum_and_arange(num_scheduled_tokens)
    query_start_loc_np = query_start_loc_cpu.numpy()
    query_start_loc_np[0] = 0
    query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
```
Copilot AI · Jan 7, 2026
Lines 630-632 access query_start_loc_np but this is only defined inside the conditional block. If get_cumsum_and_arange is None, these lines will fail with an undefined variable error. The indentation should be corrected so that lines 630-632 are inside the conditional block.
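One way to resolve this, as a sketch that assumes the surrounding code matches the snippet above, is to either indent the later uses into the same branch (as the comment suggests) or compute a fallback so the name always exists; numpy's cumsum should produce the same prefix sums as the helper's first return value.

```python
import numpy as np

# Sketch only; assumes the context shown in the snippet above.
if get_cumsum_and_arange is not None:
    cu_num_tokens, _ = get_cumsum_and_arange(num_scheduled_tokens)
else:
    # Fallback so query_start_loc_np is always populated below.
    cu_num_tokens = np.cumsum(num_scheduled_tokens)

query_start_loc_np = query_start_loc_cpu.numpy()
query_start_loc_np[0] = 0
query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
```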
```python
During chunked attention, we slice block_usages[:, chunk_start:chunk_end] and
generate bias for each chunk on-the-fly.
"""
block_usages: torch.tensor  # Dense: [num_query_tokens, num_shared_blocks]
```
Copilot AI · Jan 7, 2026
Type hint uses lowercase torch.tensor which is not the correct type. Should be torch.Tensor (capitalized) for proper type hinting.
```python
def _partial_attn_shared_core(query: torch.tensor,
                              key: torch.tensor,
                              value: torch.tensor,
                              bias: torch.tensor,
                              fmin: torch.tensor,
```
Copilot AI · Jan 7, 2026
Type hints use lowercase torch.tensor throughout the function signature. Should be torch.Tensor (capitalized) for consistency with Python type hinting conventions.
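For reference, the annotated signature would read as below; torch.Tensor is the tensor class (the correct annotation), while torch.tensor() is a factory function, not a type. The remaining parameters and the function body are omitted here.

```python
import torch

# Corrected type hints for the signature shown above.
def _partial_attn_shared_core(query: torch.Tensor,
                              key: torch.Tensor,
                              value: torch.Tensor,
                              bias: torch.Tensor,
                              fmin: torch.Tensor):
    ...
```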
Signed-off-by: Konrad Zawora <[email protected]>
🚧 CI Blocked: The main CI workflow was not started for the following reason:
Signed-off-by: Konrad Zawora <[email protected]>
🚧 CI Blocked: The main CI workflow was not started for the following reason:
Signed-off-by: Konrad Zawora <[email protected]>
🚧 CI Blocked: The main CI workflow was not started for the following reason:
🚧 CI Blocked: The main CI workflow was not started for the following reason:
This PR adds chunked processing for shared blocks in unified attention, mainly to deal with memory issues when you have a lot of shared blocks.
The problem was that the full shared bias tensor `[query_len, num_shared_blocks, block_size]` can get huge with many shared blocks (e.g. for 8k query_len and 10k shared blocks, you'd have `(8192*10000*128*2)/2^30 = 19.53 GiB`!!!) and we were hitting OOMs.

The fix is to process shared blocks in chunks:
- Each iteration takes `chunk_size` blocks (default 64), generates bias for just those blocks, computes partial attention, then uses flash-attention-style online softmax to merge results
- That brings it down to `(8192*64*128*2)/2^20 = 128 MiB` of memory needed for biases, rather than the 19.53 GiB you would need initially
- The merge is the usual `max_new = max(max_global, max_chunk)` + rescale dance (see the sketch below)
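A minimal sketch of that chunked merge, under assumed shapes and names; this is illustrative only, and the PR's `_partial_attn_shared_chunked()` may differ in layout, scaling, and masking details.

```python
import torch

def chunked_shared_attention(query, shared_keys, shared_values, make_chunk_bias,
                             num_shared_blocks, block_size, chunk_size=64):
    """Illustrative flash-attention-style online softmax over chunks of shared blocks.

    query:                 [q_len, d]
    shared_keys/values:    [num_shared_blocks * block_size, d]
    make_chunk_bias(s, e): additive bias [q_len, (e - s) * block_size], with a very
                           negative value (fmin) marking masked slots.
    Assumes every query row sees at least one unmasked slot in the first chunk.
    """
    q_len, d = query.shape
    fmin = torch.finfo(query.dtype).min

    acc = torch.zeros(q_len, d, dtype=query.dtype)      # running sum of probs @ V
    denom = torch.zeros(q_len, 1, dtype=query.dtype)    # running softmax denominator
    max_global = torch.full((q_len, 1), fmin, dtype=query.dtype)

    for block_start in range(0, num_shared_blocks, chunk_size):
        block_end = min(block_start + chunk_size, num_shared_blocks)
        kv = slice(block_start * block_size, block_end * block_size)

        scores = (query @ shared_keys[kv].T) * d**-0.5 + make_chunk_bias(block_start, block_end)

        max_chunk = scores.max(dim=-1, keepdim=True).values
        max_new = torch.maximum(max_global, max_chunk)   # max_new = max(max_global, max_chunk)

        scale = torch.exp(max_global - max_new)          # rescale previous accumulators
        probs = torch.exp(scores - max_new)
        acc = acc * scale + probs @ shared_values[kv]
        denom = denom * scale + probs.sum(dim=-1, keepdim=True)
        max_global = max_new

    return acc / denom
```

The 64-blocks-per-chunk default matches the description above; the real kernel also handles per-block usage and causal masking, which this sketch leaves to `make_chunk_bias`.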
Dense bias approach

The other issue we've had here for a long time is shared bias generation. Previously we were passing variable-length coordinate arrays `[token_idx, block_idx, usage]` to the bias generator. We first created per-block bias with block usages, and used token and block coordinates to scatter it to a big `[max_num_batched_tokens, max_num_shared_blocks, block_size]` tensor. It was pretty horrible - not just because of slow scatters, but mostly because of the dynamic coordinate dimension (`max_num_shared_tokens`) that can't be easily derived from bucketed dimensions.

Now we use a "dense scatter on CPU, broadcast on HPU" approach:
- Scatter block usages into a dense `[query_len, num_shared_blocks]` tensor (any shape works on CPU)

So, basically, now we're creating the bias tensor `(target_qlen, target_shared_blocks, block_size)` by comparing & broadcasting the relatively-reasonably-sized `[query_len, num_shared_blocks]` tensor with a `[num_shared_blocks]` tensor, and the scatter operation goes to CPU, where dynamic indexing is (relatively) cheap. And we don't have to deal with `max_num_shared_tokens` that can range all the way from 2 to a bazillion, yay.

I'm not sure if there's a reason not to use the dense biases across the board (even for non-chunked shared attn), so I've enabled it by default. I left a `unified_attn_dense_shared_bias` flag if anyone wants to disable it explicitly and use the old behavior.
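A sketch of the dense path described above, with illustrative names and assumed bias semantics (the PR's `HPUSharedBiasGeneratorDense` may differ in details): the dynamic-length scatter happens on CPU into a fixed `[query_len, num_shared_blocks]` tensor, and the device side only does a broadcast comparison against a `block_size` arange.

```python
import torch

def scatter_block_usages_cpu(token_idx, block_idx, usage, query_len, num_shared_blocks):
    """CPU-side dense scatter: variable-length coordinate lists -> fixed-shape
    [query_len, num_shared_blocks] tensor. Dynamic lengths are fine here because
    this never becomes part of the device graph."""
    block_usages = torch.zeros(query_len, num_shared_blocks, dtype=torch.int32)
    block_usages[token_idx, block_idx] = usage.to(torch.int32)
    return block_usages

def dense_bias_from_usages(block_usages, block_size, fmin):
    """Device-side broadcast: [q_len, n_blocks] -> [q_len, n_blocks, block_size].
    A slot is assumed attendable iff its in-block position is below that block's
    usage; everything else gets fmin, which softmax turns into ~0 weight."""
    positions = torch.arange(block_size, device=block_usages.device)   # [block_size]
    attendable = positions < block_usages.unsqueeze(-1)                # broadcast compare
    zeros = torch.zeros_like(attendable, dtype=torch.float32)
    return torch.where(attendable, zeros, torch.full_like(zeros, fmin))
```

For the chunked path, the same generator would be applied to a slice `block_usages[:, chunk_start:chunk_end]`, so only a `[q_len, chunk_size, block_size]` bias ever materializes on the device.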
What's new

- `SharedBlockChunkedBiasData`: holds dense block_usages for chunk-wise bias generation
- `_partial_attn_shared_chunked()`: chunked processing with online softmax merging
- `_partial_attn_shared_core()`: extracted inner loop for reuse
- `HPUSharedBiasGeneratorDense`: generates bias from dense block_usages via broadcast

Notes
- `unified_attn_softmax_fa2`: I was testing with it off (`VLLM_UNIFIED_ATTN_SOFTMAX_FA2=false`). It works functionally, but gets me to OOM much sooner than I'd expect it to.