Draft: Add FlashAttention online merge in Unified Attention #785
base: main
Conversation
Signed-off-by: Konrad Zawora <[email protected]>
🚧 CI Blocked: The main CI workflow was not started for the following reason:
Pull request overview
This PR implements FlashAttention online merge in Unified Attention to reduce memory consumption by performing rescaling incrementally rather than after computing all attention parts. The changes introduce chunked processing for shared blocks with online bias generation, avoiding materialization of large bias tensors.
Key changes:
- Added an online merge algorithm that incrementally combines attention results using flash-attention style rescaling (see the sketch after this list)
- Implemented chunked processing for shared blocks with per-chunk bias generation from dense block_usages
- Introduced dense bias generation path that scatters on CPU and broadcasts on HPU to avoid dynamic shapes
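For context, a minimal sketch of the flash-attention style rescaling that an online merge step like this relies on is shown below. It mirrors the `online_merge_step(acc_attn, acc_max, acc_sum, *part)` call shape visible in the diff, but the body is only an illustration of the standard log-sum-exp merge, not necessarily the PR's exact implementation; it assumes the running max and sum keep a trailing singleton dimension so they broadcast over `head_dim`.

```python
import torch


def online_merge_step(acc_attn, acc_max, acc_sum, part_attn, part_max, part_sum):
    """Fold one partial attention result into the running accumulator.

    Each partial result carries the unnormalized attention output, the
    per-row max of its logits, and the per-row sum of exp(logits - max).
    Rescaling both sides onto a common max lets partial outputs be merged
    as they are produced, instead of keeping every intermediate buffer
    alive until a final merge.
    """
    if acc_attn is None:
        # First partial result: nothing to merge yet.
        return part_attn, part_max, part_sum

    new_max = torch.maximum(acc_max, part_max)
    # Correction factors that bring both sides onto the new max.
    acc_scale = torch.exp(acc_max - new_max)
    part_scale = torch.exp(part_max - new_max)
    new_attn = acc_attn * acc_scale + part_attn * part_scale
    new_sum = acc_sum * acc_scale + part_sum * part_scale
    return new_attn, new_max, new_sum
```

After the last partial result has been merged, the accumulator is normalized once with `acc_attn / acc_sum`, so no per-part output buffers have to be retained.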
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| vllm_gaudi/v1/worker/hpu_model_runner.py | Updated unified config to detect both shared_bias and shared_bias_chunked, removed blank line |
| vllm_gaudi/extension/unified_batch.py | Added dense bias generation (see the sketch after this table), chunked processing logic, and new SharedBiasGeneratorDense class
| vllm_gaudi/extension/unified.py | Implemented online merge algorithm, chunked shared attention computation, and updated entry points |
| vllm_gaudi/extension/features.py | Added unified_attn_dense_shared_bias feature flag |
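As a rough illustration of the dense bias path noted above (scatter on CPU, broadcast on the device so shapes stay static), the sketch below shows one way such a generator could be structured. The function name, arguments, and shapes are assumptions made for illustration; they are not the actual SharedBiasGeneratorDense API.

```python
import torch


def dense_shared_bias(block_ids_cpu: torch.Tensor,
                      block_usages_cpu: torch.Tensor,
                      num_blocks: int,
                      block_size: int,
                      device: str = "hpu") -> torch.Tensor:
    """Hypothetical dense bias construction (names are illustrative).

    The scatter into a fixed-size dense tensor happens on CPU, so the
    device-side graph only ever sees statically shaped tensors; on the
    device, the dense usages are broadcast against per-slot positions to
    mask the unused tail of every block with -inf.
    """
    # CPU: scatter the ragged (block_id -> usage) pairs into a dense tensor.
    dense_usage = torch.zeros(num_blocks, dtype=torch.int64)
    dense_usage.scatter_(0, block_ids_cpu.to(torch.int64),
                         block_usages_cpu.to(torch.int64))

    # Device: broadcast against a per-slot arange; no dynamic shapes.
    dense_usage = dense_usage.to(device)
    slot = torch.arange(block_size, device=device)          # [block_size]
    valid = slot.unsqueeze(0) < dense_usage.unsqueeze(1)    # [num_blocks, block_size]
    bias = torch.zeros(num_blocks, block_size, device=device)
    bias.masked_fill_(~valid, float("-inf"))
    return bias.view(-1)                                    # [num_blocks * block_size]
```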
if get_cumsum_and_arange is not None:
    cu_num_tokens, _ = get_cumsum_and_arange(num_scheduled_tokens)
Copilot AI (Jan 7, 2026)
Checking get_cumsum_and_arange for None before calling it is fine, but it introduces a problem: when get_cumsum_and_arange is None, cu_num_tokens is never assigned, yet lines 746-747 still use it, so that path raises a NameError at runtime.
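One hedged way to avoid the NameError (assuming num_scheduled_tokens is a NumPy array, which may not match the runner's actual types) is to give cu_num_tokens a value on both branches; this is a sketch meant to slot into the quoted code, not a tested fix:

```python
import numpy as np

# Illustrative guard: cu_num_tokens is defined on both paths, falling back
# to a plain cumulative sum when the helper is unavailable.
if get_cumsum_and_arange is not None:
    cu_num_tokens, _ = get_cumsum_and_arange(num_scheduled_tokens)
else:
    cu_num_tokens = np.cumsum(num_scheduled_tokens)
```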
if scaled_query_latent is not None:
    shared = partial_attn_shared(query=scaled_query_latent,
Copilot AI (Jan 7, 2026)
The conditional check for scaled_query_latent was moved to wrap the entire partial_attn_shared call rather than using a ternary expression. While this is more readable, it creates code duplication with the else clause at lines 875-876 that sets shared to (None, None, None). Consider extracting this pattern into a helper or maintaining the ternary for consistency with the causal and unique attention calls.
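A hypothetical helper sketching the extraction suggested here; the name and signature are invented for illustration, and a real refactor would follow the surrounding code's conventions:

```python
from typing import Callable, Tuple


def maybe_partial_attn(query, attn_fn: Callable, **kwargs) -> Tuple:
    """Call the partial-attention function only when a query is present;
    otherwise return the empty (attn, max, sum) triple expected by the
    merge step. Keeps the shared, causal and unique call sites uniform."""
    if query is None:
        return (None, None, None)
    return attn_fn(query=query, **kwargs)
```

Usage would then read `shared = maybe_partial_attn(scaled_query_latent, partial_attn_shared, ...)` with the remaining keyword arguments passed through.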
return query_latent.flatten(-2, -1)  # [tokens, num_heads * head_dim]
if use_online_merge:
    acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique)
if use_online_merge:
Copilot AI (Jan 7, 2026)
There are two consecutive 'if use_online_merge:' checks at lines 885 and 887. These should be combined into a single conditional block to improve readability and avoid redundant checks.
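A small sketch of the consolidation being suggested; all names are taken as parameters so the fragment stands alone, and whatever the second block did is only indicated by a comment:

```python
def merge_if_online(acc, unique, use_online_merge, online_merge_step):
    # Both online-merge actions sit under a single `if use_online_merge:`
    # block instead of two adjacent checks.
    if use_online_merge:
        acc = online_merge_step(*acc, *unique)
        # ...any further online-merge-only bookkeeping would continue here...
    return acc
```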
Signed-off-by: Konrad Zawora <[email protected]>
🚧 CI Blocked: The main CI workflow was not started for the following reason:
Signed-off-by: Konrad Zawora <[email protected]>
🚧 CI Blocked: The main CI workflow was not started for the following reason:
Further experiments on top of #784 - I wanted to check if we can avoid some OOMs by performing the FlashAttention rescaling online rather than after computing all the parts, which should save us memory on some intermediate buffers. Accuracy is surprisingly okay-ish, but I haven't tested this too thoroughly.