feat: Integrate FA4 with custom BlockMask construction #96

Merged
yubofredwang merged 1 commit into main from feat/flex-attention-fa4-blockmask on May 11, 2026

Conversation

@yubofredwang (Collaborator)

Supports rectangular Q×KV block sizes (FA4 on Blackwell uses Q_BS=256, KV_BS=128).
Requires KV_BS to divide Q_BS so each Q-block aligns to an integer number of
KV-blocks, r = Q_BS / KV_BS (the only configuration FA4 emits today).
Block classification (qi = Q-block index, kj = KV-block index; sketch below):
Causal region (kv < Q_LEN):
* Strictly below diagonal (kj < qi*r) -> FULL (all True, skip mask_mod)
* Diagonal slab (kj in [qi*r, (qi+1)*r)) -> PARTIAL (lower-triangular)
* Above diagonal -> empty (omitted)
Suffix region (kv >= Q_LEN), one round per Q_LEN columns:
* Per round, Q-block qi has r PARTIAL blocks (one per KV-block slot within the Q-block).
* Suffix blocks are never FULL (the mask is at most a thin diagonal).
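
For concreteness, here is a minimal Python sketch of the classification above. The names `classify_block`, `r`, and `n_causal` are illustrative, not this PR's actual helpers; it assumes r = Q_BS // KV_BS and n_causal = Q_LEN // KV_BS.

```python
# Sketch only, not the PR's implementation.
def classify_block(qi: int, kj: int, r: int, n_causal: int):
    """Return 'FULL', 'PARTIAL', or None (empty block, omitted)."""
    if kj < n_causal:                    # causal region: kv < Q_LEN
        if kj < qi * r:                  # strictly below the diagonal
            return "FULL"                # all True -> kernel skips mask_mod
        if kj < (qi + 1) * r:            # diagonal slab
            return "PARTIAL"             # lower-triangular inside the block
        return None                      # above the diagonal
    # Suffix region: rounds of n_causal KV-blocks; in each round, Q-block qi
    # owns the same r diagonal slots, and such blocks are never FULL.
    off = (kj - n_causal) % n_causal     # slot offset within the current round
    return "PARTIAL" if qi * r <= off < (qi + 1) * r else None
```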

@yubofredwang force-pushed the feat/flex-attention-fa4-blockmask branch from b56377d to 3e6e47e on May 11, 2026 22:53
feat: Integrate FA4 with custom BlockMask

Extends the analytical Eagle3 BlockMask builder (PR #91) to support
Flash Attention 4 (Blackwell SM100+) and emit FULL block tensors so the
attention kernel can fast-path strict-below-diagonal blocks.

Changes
-------
torchspec/models/ops/flex_attention.py
- _build_eagle3_block_mask_tensors now returns 8 tensors
  (kv_num, kv_idx, full_kv_num, full_kv_idx,
   q_num,  q_idx,  full_q_num,  full_q_idx) matching the BlockMask layout
  consumed by both flex_attention's create_block_mask and FA4's
  BlockSparseTensorsTorch (see the construction sketch after this list).
- Accepts rectangular Q_BS x KV_BS block sizes (FA4 uses 256x128 on
  Blackwell, 128x128 on Hopper); BLOCK_SIZE arg accepts either an int
  (square) or a (Q_BS, KV_BS) tuple. Default stays 128 (square,
  backward-compatible with the existing FlexAttention path).
- Block classification:
    Causal region (kv < Q_LEN):
      - strictly below diagonal -> FULL (kernel skips mask_mod)
      - diagonal slab           -> PARTIAL (lower-triangular)
    Suffix region (kv >= Q_LEN): per round, r PARTIAL diagonal blocks per
      Q-block; suffix blocks are never FULL.
- compile_friendly_create_block_mask threads BLOCK_SIZE through to
  create_block_mask so the fallback path can also run with rectangular
  blocks.
- eagle3_block_mask dispatcher accepts the same int|tuple BLOCK_SIZE and
  preserves the analytical-vs-fallback decision.
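
As referenced above, a hedged sketch of how the kv-side tensors could be wrapped into a torch BlockMask; from_kv_blocks derives the q-side tensors itself, so the builder's q-side outputs would feed FA4's BlockSparseTensorsTorch instead. `to_block_mask` and the shape comments are assumptions, not code from this PR.

```python
from torch.nn.attention.flex_attention import BlockMask

def to_block_mask(kv_num, kv_idx, full_kv_num, full_kv_idx,
                  mask_mod, q_bs=256, kv_bs=128):
    # Rectangular BLOCK_SIZE=(Q_BS, KV_BS), e.g. (256, 128) for FA4 on Blackwell.
    return BlockMask.from_kv_blocks(
        kv_num_blocks=kv_num,            # per-Q-block count of PARTIAL blocks
        kv_indices=kv_idx,               # padded to width n_kv, as FA4 asserts
        full_kv_num_blocks=full_kv_num,  # per-Q-block count of FULL blocks
        full_kv_indices=full_kv_idx,
        BLOCK_SIZE=(q_bs, kv_bs),
        mask_mod=mask_mod,               # evaluated only on PARTIAL blocks
    )
```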

tests/test_build_eagle3_block_mask.py
- test_full_and_partial_blocks_match_create_block_mask: parity for both
  PARTIAL and FULL kv/q tensors against create_block_mask ground truth.
- test_non_square_block_size_forward: end-to-end forward with
  Q_BS=256, KV_BS=128 matches the reference flex_attention output.
- test_dispatcher_with_non_square_block_size: tuple BLOCK_SIZE routes
  through the analytical path.
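
A minimal version of the parity check, as a sketch under assumed shapes rather than the actual tests in tests/test_build_eagle3_block_mask.py; `check_parity` and its wiring are illustrative.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def check_parity(tensors, mask_mod, B, H, Q_LEN, KV_LEN, q_bs, kv_bs):
    # Ground truth from flex_attention's own builder.
    ref = create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN,
                            BLOCK_SIZE=(q_bs, kv_bs), device="cuda")
    kv_num, kv_idx, full_kv_num, full_kv_idx = tensors[:4]
    torch.testing.assert_close(kv_num, ref.kv_num_blocks)
    torch.testing.assert_close(full_kv_num, ref.full_kv_num_blocks)
    # Index rows are meaningful only up to each row's count; padding beyond
    # that is unspecified, so compare only the leading kv_num entries per row.
```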

Performance
-----------
- On Hopper (square 128): mask-tensor memory grows ~3x (KB/MB scale,
  negligible) because kv_indices is padded to width n_kv as required by
  flex_attention/FA4. Attention kernel is unchanged or faster: per-Q-row
  PARTIAL count drops from (qi + n_rounds) to constant n_rounds, with
  the rest treated as FULL and skipping mask_mod.
- On Blackwell SM100+: this enables FA4 to run at all -- FA4's
  infer_block_sparse_expected_shapes asserts kv_indices.shape[-1] == n_kv,
  which the previous (compact) layout violated.
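
To make the per-Q-row claim concrete, a toy count using the formula above (numbers are illustrative):

```python
# Hopper square-block case (r = 1): 8 causal Q-blocks, 3 suffix rounds.
n_q, n_rounds = 8, 3
old_partial = [qi + n_rounds for qi in range(n_q)]  # every causal block ran mask_mod
new_partial = [n_rounds] * n_q                      # below-diagonal blocks now FULL
print(old_partial)  # [3, 4, 5, 6, 7, 8, 9, 10]
print(new_partial)  # [3, 3, 3, 3, 3, 3, 3, 3]
```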

Correctness is verified against create_block_mask at the dense, forward,
backward, GQA-broadcast, and PARTIAL/FULL block-tensor levels; the
torch.compile path is verified to match eager and not to recompile across
TTT-step KV growth (see the sketch below).
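
One way to assert the no-recompile property, as a sketch rather than the repo's actual test; `attention_step` and the shapes are hypothetical stand-ins.

```python
import torch
import torch.nn.functional as F

def attention_step(q, k, v):  # hypothetical stand-in for the TTT attention call
    return F.scaled_dot_product_attention(q, k, v)

torch._dynamo.config.error_on_recompile = True      # any recompile now raises
compiled = torch.compile(attention_step, dynamic=True)
q = torch.randn(1, 8, 256, 64)
for kv_len in (1024, 1152, 1280):                    # KV cache grows per TTT step
    k = torch.randn(1, 8, kv_len, 64)
    v = torch.randn(1, 8, kv_len, 64)
    compiled(q, k, v)                                # must not trigger recompilation
```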

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
@yubofredwang force-pushed the feat/flex-attention-fa4-blockmask branch from 3e6e47e to d9916eb on May 11, 2026 22:55
@yubofredwang marked this pull request as ready for review on May 11, 2026 23:04
@yubofredwang merged commit 5c865bd into main on May 11, 2026
2 checks passed
@yubofredwang deleted the feat/flex-attention-fa4-blockmask branch on May 11, 2026 23:26