
feat: analytical Eagle3 BlockMask builder (O(num_blocks) memory)#91

Merged
yubofredwang merged 2 commits into main from feat/eagle3-block-mask on May 10, 2026
Conversation

yubofredwang (Collaborator) commented May 5, 2026

Summary

  • Add build_eagle3_block_mask, an O(num_blocks)-memory direct constructor for the Eagle3 causal+suffix BlockMask. create_block_mask materialises the full (Q_LEN, KV_LEN) boolean grid internally (~112 GB at Q=49K, KV=245K), causing OOM on long-context Eagle3 training. The analytical builder writes the sparse kv_indices/q_indices tensors directly from the known mask structure, dropping peak memory to a few MB (see the sketch after this list).
  • Add eagle3_block_mask dispatcher: uses the analytical builder when the preconditions hold (Q_LEN % BLOCK_SIZE == 0, KV_LEN % BLOCK_SIZE == 0, KV_LEN % Q_LEN == 0), otherwise falls back to create_block_mask / compile_friendly_create_block_mask (covers tests/edge cases where the O(Q*KV) cost is irrelevant).
  • Simplify generate_eagle3_mask to drop the seq_lengths-aware padding mask. It's now only consumed by the dispatcher's fallback path; padding is enforced upstream via the attention mask, and the analytical path gets exact equivalence by construction (verified by tests below).
  • Re-export build_eagle3_block_mask and eagle3_block_mask from torchspec.models.ops.
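
For readers unfamiliar with the mask shape: the Eagle3 mask is causal over the first Q_LEN KV columns, plus one diagonal per appended suffix round, so each Q block only touches its causal prefix blocks and one diagonal block per round. The sketch below illustrates that analytical construction. It is not the torchspec implementation: the name sketch_eagle3_block_mask and its signature are hypothetical, and it relies on PyTorch's BlockMask.from_kv_blocks rather than writing the q-side tensors directly the way build_eagle3_block_mask does.

```python
# Hedged sketch of the analytical O(num_blocks) construction (illustrative only).
import torch
from torch.nn.attention.flex_attention import BlockMask


def sketch_eagle3_block_mask(Q_LEN, KV_LEN, BLOCK_SIZE=128, device="cuda"):
    # Preconditions of the analytical path in this PR.
    assert Q_LEN % BLOCK_SIZE == 0 and KV_LEN % BLOCK_SIZE == 0 and KV_LEN % Q_LEN == 0

    n_q_blocks = Q_LEN // BLOCK_SIZE     # Q blocks (and KV blocks per round)
    n_rounds = KV_LEN // Q_LEN           # 1 causal round + (n_rounds - 1) suffix rounds
    n_kv_blocks = n_rounds * n_q_blocks

    # Per Q-block qi: blocks 0..qi in the causal region plus one diagonal block
    # per suffix round, i.e. (qi + 1) + (n_rounds - 1) active blocks.
    qi = torch.arange(n_q_blocks, device=device)
    kv_num_blocks = ((qi + 1) + (n_rounds - 1)).to(torch.int32)

    # Write the active block-column indices directly; no dense (Q, KV) grid.
    kv_indices = torch.zeros(n_q_blocks, n_kv_blocks, dtype=torch.int32, device=device)
    for i in range(n_q_blocks):
        causal = torch.arange(i + 1, device=device)                         # 0..qi
        suffix = torch.arange(1, n_rounds, device=device) * n_q_blocks + i  # diagonals
        cols = torch.cat([causal, suffix])
        kv_indices[i, : cols.numel()] = cols.to(torch.int32)

    # mask_mod still refines the partial (diagonal) blocks element-wise.
    def eagle3_mask(b, h, q_idx, kv_idx):
        causal = (kv_idx < Q_LEN) & (q_idx >= kv_idx)
        suffix = (kv_idx >= Q_LEN) & ((kv_idx - q_idx) % Q_LEN == 0)
        return causal | suffix

    return BlockMask.from_kv_blocks(
        kv_num_blocks[None, None],   # broadcast batch/head dims
        kv_indices[None, None],
        BLOCK_SIZE=BLOCK_SIZE,
        mask_mod=eagle3_mask,
    )
```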

Testing

New file: tests/test_build_eagle3_block_mask.py (18 tests + 13 subtests).

Local run on torch 2.9.1+cu129, single CUDA device:

$ pytest tests/test_build_eagle3_block_mask.py -x -q
..................                                          [100%]
18 passed, 1 warning, 13 subtests passed in 7.52s

Smoke import check confirms LlamaFlexAttention and DeepSeekMLAFlexAttention still load with the dispatcher import. No new lints.
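
For orientation, a minimal sketch of the kind of analytical-vs-reference parity check these tests perform. This is illustrative only; sizes, tolerances, and the exact eagle3_block_mask argument names below are assumptions, not copied from the test file.

```python
# Illustrative parity check (not from tests/test_build_eagle3_block_mask.py).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from torchspec.models.ops import eagle3_block_mask  # re-exported by this PR

B, H, D = 1, 8, 64
Q_LEN, KV_LEN = 1024, 4096
device = "cuda"

def reference_mask(b, h, q_idx, kv_idx):
    causal = (kv_idx < Q_LEN) & (q_idx >= kv_idx)
    suffix = (kv_idx >= Q_LEN) & ((kv_idx - q_idx) % Q_LEN == 0)
    return causal | suffix

# Reference path: create_block_mask materialises the dense grid internally,
# which is fine at these small test sizes.
ref_mask = create_block_mask(reference_mask, B, H, Q_LEN, KV_LEN, device=device)

# Analytical path via the dispatcher added in this PR (signature assumed).
ana_mask = eagle3_block_mask(Q_LEN, KV_LEN, device=device)

q = torch.randn(B, H, Q_LEN, D, device=device, dtype=torch.float16)
k = torch.randn(B, H, KV_LEN, D, device=device, dtype=torch.float16)
v = torch.randn(B, H, KV_LEN, D, device=device, dtype=torch.float16)

out_ref = flex_attention(q, k, v, block_mask=ref_mask)
out_ana = flex_attention(q, k, v, block_mask=ana_mask)
torch.testing.assert_close(out_ref, out_ana, rtol=2e-3, atol=2e-3)
```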

Solves #86

Copilot AI review requested due to automatic review settings May 5, 2026 23:02

chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8a8981fbde


Comment on lines 121 to +125
    def causal_mask(b, h, q_idx, kv_idx):
-       # Causal will keep shrinking by 1 diagonal due to the appended suffix
-       # Shrink the causal mask by the diagonal
-       causal_mask = q_idx >= kv_idx
-       padding_mask = (kv_idx < seq_lengths[b]) & (q_idx < seq_lengths[b])
-       return causal_mask & padding_mask
+       return q_idx >= kv_idx

    def suffix_mask(b, h, q_idx, kv_idx):
-       suffix_mask = kv_idx >= Q_LEN
-       padding_mask = kv_idx % Q_LEN < seq_lengths[b]
-       diagonal_mask = (kv_idx - q_idx) % Q_LEN == 0
-       return suffix_mask & padding_mask & diagonal_mask
+       return (kv_idx >= Q_LEN) & ((kv_idx - q_idx) % Q_LEN == 0)
P1: Restore per-sequence padding checks in Eagle3 mask

The new generate_eagle3_mask ignores seq_lengths, so both the analytical dispatcher fallback and any direct callers now build a mask as if every sample were full length. In padded batches (attention_mask has zeros), this removes the per-example q_idx < seq_lengths[b] / kv_idx < seq_lengths[b] filtering that the previous version used, so padded rows/columns are treated as valid attention positions. This changes training semantics for variable-length batches and can inject pad-token attention/gradients where masking was previously enforced.
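
For concreteness, a small dense-mask comparison of the removed (seq_lengths-aware) and simplified mask definitions on one padded example. This is an editorial illustration of the point above, not code from the PR.

```python
# Editorial illustration: dense-mask view of the removed seq_lengths filtering
# vs. the simplified mask, for one padded sample.
import torch

Q_LEN, KV_LEN, seq_len = 8, 16, 5          # sample has 3 padding positions
q = torch.arange(Q_LEN)[:, None]
kv = torch.arange(KV_LEN)[None, :]

causal = (kv < Q_LEN) & (q >= kv)
suffix = (kv >= Q_LEN) & ((kv - q) % Q_LEN == 0)
new_mask = causal | suffix                  # simplified: ignores seq_len

pad_causal = (kv < seq_len) & (q < seq_len)
pad_suffix = (kv % Q_LEN) < seq_len
old_mask = (causal & pad_causal) | (suffix & pad_suffix)

# Positions the old mask filtered out but the simplified one keeps:
# exactly the rows/columns that fall in the padded tail.
print((new_mask & ~old_mask).nonzero())
```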


yubofredwang (Collaborator, Author) replied:

This is useful, but we don't support padded batches yet, so until we do, this is not a concern.

Copilot AI (Contributor) left a comment


Pull request overview

This PR introduces an analytical constructor for the Eagle3 causal+suffix BlockMask to avoid the extreme peak memory costs of create_block_mask on long-context Eagle3 training runs, and wires the new dispatcher into the flex-attention draft model implementations.

Changes:

  • Add build_eagle3_block_mask to directly construct kv_indices/q_indices in O(num_blocks) memory, plus eagle3_block_mask to dispatch to the analytical path when shapes are aligned.
  • Switch LlamaFlexAttention and DeepSeekMLAFlexAttention to use eagle3_block_mask instead of building masks via create_block_mask.
  • Re-export the new APIs from torchspec.models.ops and add targeted unit tests for analytical-vs-reference parity.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Summary per file:
torchspec/models/ops/flex_attention.py: Adds the analytical Eagle3 BlockMask builder and dispatcher; modifies generate_eagle3_mask.
torchspec/models/ops/__init__.py: Re-exports build_eagle3_block_mask and eagle3_block_mask.
torchspec/models/draft/llama3_eagle.py: Adopts the dispatcher in LlamaFlexAttention.
torchspec/models/draft/deepseek_eagle.py: Adopts the dispatcher in DeepSeekMLAFlexAttention.
tests/test_build_eagle3_block_mask.py: Adds unit tests for mask structure, flex_attention parity, gradients, and dispatcher routing.


yubofredwang marked this pull request as draft on May 5, 2026 23:13
Add `build_eagle3_block_mask`, an O(num_blocks)-memory direct constructor
for the Eagle3 causal+suffix BlockMask, plus an `eagle3_block_mask`
dispatcher that falls back to `create_block_mask` when shape preconditions
don't hold (Q_LEN/KV_LEN BLOCK_SIZE alignment, KV_LEN multiple of Q_LEN).

Motivation: `create_block_mask` materialises the full (Q_LEN, KV_LEN)
boolean grid internally, costing ~112 GB at Q=49K, KV=245K and OOM-ing
on long-context Eagle3 training. The analytical builder writes the
sparse `kv_indices` / `q_indices` tensors directly from the known mask
structure, reducing peak memory to a few MB.

Also simplifies `generate_eagle3_mask` to drop the seq_lengths-aware
padding mask (now only used by the fallback path; padding is enforced
upstream via the attention mask).

Adopts the dispatcher in `LlamaFlexAttention` and `DeepSeekMLAFlexAttention`.
Adoption in MoE draft models will land in a follow-up.

Tests: tests/test_build_eagle3_block_mask.py covers element-level
equivalence vs `create_block_mask` reference, GQA, forward/backward
parity, dispatcher fallback paths, and a memory-bound assertion
(<10 MB at Q=4096, KV=20480).

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
yubofredwang force-pushed the feat/eagle3-block-mask branch from 8a8981f to 2b242a9 on May 6, 2026 00:08
yubofredwang marked this pull request as ready for review on May 10, 2026 08:15
yubofredwang force-pushed the feat/eagle3-block-mask branch from 8776373 to 3424318 on May 10, 2026 08:15
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
yubofredwang force-pushed the feat/eagle3-block-mask branch from 3424318 to 78eba5c on May 10, 2026 08:46
yubofredwang merged commit e63cfab into main on May 10, 2026
2 checks passed
yubofredwang deleted the feat/eagle3-block-mask branch on May 10, 2026 08:47
yubofredwang added a commit that referenced this pull request May 11, 2026
…lockMask

Extends the analytical Eagle3 BlockMask builder (PR #91) to support
Flash Attention 4 (Blackwell SM100+) and emit FULL block tensors so the
attention kernel can fast-path strict-below-diagonal blocks.

Changes
-------
torchspec/models/ops/flex_attention.py
- _build_eagle3_block_mask_tensors now returns 8 tensors
  (kv_num, kv_idx, full_kv_num, full_kv_idx,
   q_num,  q_idx,  full_q_num,  full_q_idx) matching the BlockMask layout
  consumed by both flex_attention's create_block_mask and FA4's
  BlockSparseTensorsTorch.
- Accepts rectangular Q_BS x KV_BS block sizes (FA4 uses 256x128 on
  Blackwell, 128x128 on Hopper); BLOCK_SIZE arg accepts either an int
  (square) or a (Q_BS, KV_BS) tuple. Default stays 128 (square,
  backward-compatible with the existing FlexAttention path).
- Block classification:
    Causal region (kv < Q_LEN):
      - strictly below diagonal -> FULL (kernel skips mask_mod)
      - diagonal slab           -> PARTIAL (lower-triangular)
    Suffix region (kv >= Q_LEN): per round, r PARTIAL diagonal blocks per
      Q-block; suffix blocks are never FULL.
- compile_friendly_create_block_mask threads BLOCK_SIZE through to
  create_block_mask so the fallback path can also run with rectangular
  blocks.
- eagle3_block_mask dispatcher accepts the same int|tuple BLOCK_SIZE and
  preserves the analytical-vs-fallback decision.

tests/test_build_eagle3_block_mask.py
- test_full_and_partial_blocks_match_create_block_mask: parity for both
  PARTIAL and FULL kv/q tensors against create_block_mask ground truth.
- test_non_square_block_size_forward: end-to-end forward with
  Q_BS=256, KV_BS=128 matches the reference flex_attention output.
- test_dispatcher_with_non_square_block_size: tuple BLOCK_SIZE routes
  through the analytical path.

Performance
-----------
- On Hopper (square 128): mask-tensor memory grows ~3x (KB/MB scale,
  negligible) because kv_indices is padded to width n_kv as required by
  flex_attention/FA4. Attention kernel is unchanged or faster: per-Q-row
  PARTIAL count drops from (qi + n_rounds) to constant n_rounds, with
  the rest treated as FULL and skipping mask_mod.
- On Blackwell SM100+: this enables FA4 to run at all -- FA4's
  infer_block_sparse_expected_shapes asserts kv_indices.shape[-1] == n_kv,
  which the previous (compact) layout violated.

Correctness is verified against create_block_mask at the dense, forward,
backward, GQA-broadcast, and PARTIAL/FULL block-tensor levels; the
torch.compile path is verified to match eager and not to recompile across
TTT-step KV growth.

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
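
As a companion to the block-classification rules described in this commit message, here is a hedged sketch of the per-Q-block FULL/PARTIAL classification under the stated preconditions. The function name classify_q_block and its return layout are illustrative, not the torchspec API.

```python
# Hedged sketch of the FULL/PARTIAL block classification (illustrative only).
from typing import List, Tuple


def classify_q_block(qi: int, Q_LEN: int, KV_LEN: int,
                     Q_BS: int = 256, KV_BS: int = 128
                     ) -> Tuple[List[int], List[int]]:
    """Return (full_kv_blocks, partial_kv_blocks) for Q-block `qi`.

    Assumes Q_LEN % Q_BS == 0, Q_LEN % KV_BS == 0, KV_LEN % Q_LEN == 0 and
    Q_BS % KV_BS == 0 (e.g. FA4's 256x128 on Blackwell, 128x128 on Hopper).
    """
    ratio = Q_BS // KV_BS                  # KV blocks spanned by one Q block
    kv_blocks_per_round = Q_LEN // KV_BS
    n_rounds = KV_LEN // Q_LEN

    # Causal region: strictly-below-diagonal blocks are FULL (the kernel skips
    # mask_mod there); the diagonal slab of `ratio` blocks stays PARTIAL.
    full = list(range(ratio * qi))
    partial = list(range(ratio * qi, ratio * qi + ratio))

    # Suffix region: one diagonal slab of `ratio` PARTIAL blocks per extra
    # round; suffix blocks are never FULL (one diagonal element per row).
    for r in range(1, n_rounds):
        base = r * kv_blocks_per_round + ratio * qi
        partial.extend(range(base, base + ratio))

    return full, partial


# With square 128x128 blocks this reduces to: qi FULL blocks, 1 diagonal
# PARTIAL block, and (n_rounds - 1) suffix PARTIAL blocks per Q row.
print(classify_q_block(qi=3, Q_LEN=1024, KV_LEN=4096, Q_BS=128, KV_BS=128))
```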
