Skip to content

Conversation

@ClarkChin08
Copy link

No description provided.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds support for causal masking in the flash attention implementation by introducing a new SubgroupLayoutQK template parameter and implementing the causal mask logic in the mainloop.

Key Changes:

  • Added SubgroupLayoutQK template parameter to the collective mainloop and kernel interfaces
  • Implemented causal masking logic that applies -INFINITY to attention scores beyond the causal boundary
  • Updated the example runner to conditionally instantiate causal or non-causal configurations based on user options

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp Implements causal mask logic and removes the static assertion that previously blocked causal mask usage
applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp Adds subgroup layout type alias and computes sequence coordinates for causal masking
examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp Adds SubgroupLayoutQK template parameter to mainloop type
examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp Conditionally selects causal or non-causal kernel based on is_causal option

@ClarkChin08 ClarkChin08 force-pushed the fa_causal_mask branch 2 times, most recently from bb07ccc to 836f2c4 Compare November 10, 2025 08:14
Signed-off-by: Chen, Xi2 <[email protected]>
Signed-off-by: Chen, Xi2 <[email protected]>
Signed-off-by: Chen, Xi2 <[email protected]>
@tdeng5 tdeng5 merged commit fb8c97c into intel:main Nov 12, 2025
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants