-
Notifications
You must be signed in to change notification settings - Fork 68
Add CausalMask support with new flash attention api #604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp
Outdated
Show resolved
Hide resolved
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds support for causal masking in the flash attention implementation by introducing a new SubgroupLayoutQK template parameter and implementing the causal mask logic in the mainloop.
Key Changes:
- Added
SubgroupLayoutQKtemplate parameter to the collective mainloop and kernel interfaces - Implemented causal masking logic that applies
-INFINITYto attention scores beyond the causal boundary - Updated the example runner to conditionally instantiate causal or non-causal configurations based on user options
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp |
Implements causal mask logic and removes the static assertion that previously blocked causal mask usage |
applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp |
Adds subgroup layout type alias and computes sequence coordinates for causal masking |
examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp |
Adds SubgroupLayoutQK template parameter to mainloop type |
examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp |
Conditionally selects causal or non-causal kernel based on is_causal option |
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Chen, Xi2 <[email protected]>
Signed-off-by: Chen, Xi2 <[email protected]>
bb07ccc to
836f2c4
Compare
Signed-off-by: Chen, Xi2 <[email protected]>
836f2c4 to
21a1bce
Compare
Signed-off-by: Chen, Xi2 <[email protected]>
Signed-off-by: Chen, Xi2 <[email protected]>
No description provided.