feat: Integrate FA4 with custom BlockMask construction#96
Merged
Extends the analytical Eagle3 BlockMask builder (PR #91) to support Flash Attention 4 (Blackwell SM100+) and emit FULL block tensors so the attention kernel can fast-path strictly-below-diagonal blocks.

Changes
-------

torchspec/models/ops/flex_attention.py
- _build_eagle3_block_mask_tensors now returns 8 tensors (kv_num, kv_idx, full_kv_num, full_kv_idx, q_num, q_idx, full_q_num, full_q_idx) matching the BlockMask layout consumed by both flex_attention's create_block_mask and FA4's BlockSparseTensorsTorch.
- Accepts rectangular Q_BS x KV_BS block sizes (FA4 uses 256x128 on Blackwell, 128x128 on Hopper); the BLOCK_SIZE arg accepts either an int (square) or a (Q_BS, KV_BS) tuple. The default stays 128 (square, backward-compatible with the existing FlexAttention path).
- Block classification:
  - Causal region (kv < Q_LEN):
    - strictly below the diagonal -> FULL (kernel skips mask_mod)
    - diagonal slab -> PARTIAL (lower-triangular)
  - Suffix region (kv >= Q_LEN): per round, r PARTIAL diagonal blocks per Q-block; suffix blocks are never FULL.
- compile_friendly_create_block_mask threads BLOCK_SIZE through to create_block_mask so the fallback path can also run with rectangular blocks.
- The eagle3_block_mask dispatcher accepts the same int|tuple BLOCK_SIZE and preserves the analytical-vs-fallback decision.

tests/test_build_eagle3_block_mask.py
- test_full_and_partial_blocks_match_create_block_mask: parity for both PARTIAL and FULL kv/q tensors against create_block_mask ground truth.
- test_non_square_block_size_forward: end-to-end forward with Q_BS=256, KV_BS=128 matches the reference flex_attention output.
- test_dispatcher_with_non_square_block_size: a tuple BLOCK_SIZE routes through the analytical path.

Performance
-----------
- On Hopper (square 128): mask-tensor memory grows ~3x (KB/MB scale, negligible) because kv_indices is padded to width n_kv, as required by flex_attention/FA4. The attention kernel is unchanged or faster: the per-Q-row PARTIAL count drops from (qi + n_rounds) to a constant n_rounds, with the rest treated as FULL and skipping mask_mod.
- On Blackwell SM100+: this enables FA4 to run at all -- FA4's infer_block_sparse_expected_shapes asserts kv_indices.shape[-1] == n_kv, which the previous (compact) layout violated.

Correctness is verified against create_block_mask at the dense, forward, backward, GQA-broadcast, and PARTIAL/FULL block-tensor levels; the torch.compile path is verified to match eager and not to recompile across TTT-step KV growth.

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
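The FULL/PARTIAL classification described in the commit message can be sketched in plain Python. This is an illustrative model, not the repo's `_build_eagle3_block_mask_tensors`: the KV layout (one causal prefix of Q_LEN columns followed by n_rounds suffix copies of the same width) and the divisibility assumptions are taken from the description above, and the function name is hypothetical.

```python
def classify_blocks(q_len, n_rounds, q_bs, kv_bs):
    """Classify each (q_block, kv_block) cell as FULL, PARTIAL, or EMPTY.

    Illustrative sketch assuming:
      - KV layout = [causal prefix, q_len cols][n_rounds suffix copies,
        each q_len cols wide]
      - kv_bs divides q_bs, and block sizes divide q_len
    """
    assert q_bs % kv_bs == 0 and q_len % q_bs == 0 and q_len % kv_bs == 0
    r = q_bs // kv_bs                    # KV-blocks spanned by one Q-block
    n_qb = q_len // q_bs                 # number of Q-block rows
    kv_len = q_len * (1 + n_rounds)
    n_kvb = kv_len // kv_bs              # number of KV-block columns
    blocks_per_round = q_len // kv_bs

    grid = [["EMPTY"] * n_kvb for _ in range(n_qb)]
    for qi in range(n_qb):
        # Causal region: strictly below the diagonal slab -> FULL,
        # the r diagonal KV-blocks -> PARTIAL (lower-triangular inside).
        for kj in range(qi * r):
            grid[qi][kj] = "FULL"
        for kj in range(qi * r, (qi + 1) * r):
            grid[qi][kj] = "PARTIAL"
        # Suffix region: each round contributes r PARTIAL diagonal blocks
        # per Q-block; suffix blocks are never FULL.
        for rd in range(n_rounds):
            base = (1 + rd) * blocks_per_round
            for kj in range(base + qi * r, base + (qi + 1) * r):
                grid[qi][kj] = "PARTIAL"
    return grid
```

With Q_LEN=512, n_rounds=2, and the Blackwell 256x128 block shape, each Q-block row carries exactly r PARTIAL causal blocks plus r PARTIAL blocks per round, matching the constant per-row PARTIAL count noted under Performance.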
3e6e47e to
d9916eb
Compare
Supports rectangular Q×KV block sizes (FA4 Blackwell uses Q_BS=256, KV_BS=128).
Requires KV_BS to divide Q_BS so each Q-block aligns to an integer number of
KV-blocks (the only configuration FA4 emits today).
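The int-or-tuple BLOCK_SIZE argument and the KV_BS-divides-Q_BS requirement can be captured by a small normalization helper. This is a hypothetical sketch, not the repo's actual argument handling; the function name is made up for illustration.

```python
def normalize_block_size(block_size):
    """Normalize an int-or-tuple BLOCK_SIZE into (Q_BS, KV_BS).

    Hypothetical helper: an int means a square block (the
    backward-compatible default), a tuple is (Q_BS, KV_BS). Enforces
    the divisibility requirement described above.
    """
    if isinstance(block_size, int):
        q_bs = kv_bs = block_size
    else:
        q_bs, kv_bs = block_size
    if q_bs % kv_bs != 0:
        raise ValueError(f"KV_BS={kv_bs} must divide Q_BS={q_bs}")
    return q_bs, kv_bs
```

For example, `normalize_block_size(128)` yields the square Hopper shape, while `normalize_block_size((256, 128))` yields the FA4 Blackwell shape with r = 2 KV-blocks per Q-block.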
Block classification:
Causal region (kv < Q_LEN):
* Strictly below diagonal (kj < qir) -> FULL (all True, skip mask_mod)
* Diagonal slab (kj in [qir, (qi+1)*r)) -> PARTIAL (lower-triangular)
* Above diagonal -> empty (omitted)
Suffix region (kv >= Q_LEN), one round per Q_LEN columns:
* Per round, Q-block qi has r PARTIAL blocks (one per KV-block slot within the Q-block).
* Suffix blocks are never FULL (the mask is at most a thin diagonal).
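Once each block is classified, the per-row KV tensors of the BlockMask layout follow mechanically. The sketch below packs a classification grid into the four KV-side tensors as plain lists; it is illustrative (the function name is hypothetical), but it shows the key detail from this PR: index rows are padded to the full KV-block count n_kv, the width FA4's infer_block_sparse_expected_shapes asserts on.

```python
def to_kv_tensors(grid):
    """Pack a FULL/PARTIAL/EMPTY grid (rows over KV-block columns) into
    (kv_num, kv_idx, full_kv_num, full_kv_idx) as plain lists.

    Illustrative sketch: each index row is zero-padded to width n_kv,
    matching the padded layout the BlockMask consumers expect.
    """
    n_kv = len(grid[0])
    kv_num, kv_idx, full_kv_num, full_kv_idx = [], [], [], []
    for row in grid:
        partial = [j for j, c in enumerate(row) if c == "PARTIAL"]
        full = [j for j, c in enumerate(row) if c == "FULL"]
        kv_num.append(len(partial))
        full_kv_num.append(len(full))
        # Entries past the per-row count are never read by the kernel,
        # so padding with zeros is safe.
        kv_idx.append(partial + [0] * (n_kv - len(partial)))
        full_kv_idx.append(full + [0] * (n_kv - len(full)))
    return kv_num, kv_idx, full_kv_num, full_kv_idx
```

Padding every row to n_kv (rather than the compact per-row maximum) is what trades the ~3x mask-tensor memory growth for FA4 compatibility.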