feat: analytical Eagle3 BlockMask builder (O(num_blocks) memory) #91
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 8a8981fbde
```diff
 def causal_mask(b, h, q_idx, kv_idx):
-    # Causal will keep shrinking by 1 diagnol due to appended suffix
-    # Shirnk the causal by diagnol
-    causal_mask = q_idx >= kv_idx
-    padding_mask = (kv_idx < seq_lengths[b]) & (q_idx < seq_lengths[b])
-    return causal_mask & padding_mask
+    return q_idx >= kv_idx


 def suffix_mask(b, h, q_idx, kv_idx):
-    suffix_mask = kv_idx >= Q_LEN
-    padding_mask = kv_idx % Q_LEN < seq_lengths[b]
-    diagnol_mask = (kv_idx - q_idx) % Q_LEN == 0
-    return suffix_mask & padding_mask & diagnol_mask
+    return (kv_idx >= Q_LEN) & ((kv_idx - q_idx) % Q_LEN == 0)
```
Restore per-sequence padding checks in Eagle3 mask
The new generate_eagle3_mask ignores seq_lengths, so both the analytical dispatcher fallback and any direct callers now build a mask as if every sample were full length. In padded batches (attention_mask has zeros), this removes the per-example q_idx < seq_lengths[b] / kv_idx < seq_lengths[b] filtering that the previous version used, so padded rows/columns are treated as valid attention positions. This changes training semantics for variable-length batches and can inject pad-token attention/gradients where masking was previously enforced.
This is useful, but we don't support padded batches yet, so until we do, this is not a concern.
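For context, the two simplified mask_mods in the diff above combine into the Eagle3 pattern that the reference path hands to `create_block_mask`. A minimal sketch of that reference construction (sequence lengths are illustrative, and combining the two mask_mods via `or_masks` is an assumption about the wiring, not verified against the repo source):

```python
from torch.nn.attention.flex_attention import create_block_mask, or_masks

Q_LEN, KV_LEN = 4096, 20480  # illustrative; KV_LEN is a multiple of Q_LEN

def causal_mask(b, h, q_idx, kv_idx):
    # Lower-triangular attention within the first Q_LEN columns.
    return q_idx >= kv_idx

def suffix_mask(b, h, q_idx, kv_idx):
    # One diagonal per appended suffix round: kv_idx == q_idx + r * Q_LEN, r >= 1.
    return (kv_idx >= Q_LEN) & ((kv_idx - q_idx) % Q_LEN == 0)

eagle3_mask = or_masks(causal_mask, suffix_mask)

# Reference path: create_block_mask evaluates the mask over the full
# (Q_LEN, KV_LEN) grid, which is exactly the peak-memory cost this PR removes.
block_mask = create_block_mask(eagle3_mask, B=None, H=None,
                               Q_LEN=Q_LEN, KV_LEN=KV_LEN, device="cuda")
```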
Pull request overview
This PR introduces an analytical constructor for the Eagle3 causal+suffix BlockMask to avoid the extreme peak memory costs of create_block_mask on long-context Eagle3 training runs, and wires the new dispatcher into the flex-attention draft model implementations.
Changes:
- Add `build_eagle3_block_mask` to directly construct `kv_indices`/`q_indices` in O(num_blocks) memory, plus `eagle3_block_mask` to dispatch to the analytical path when shapes are aligned.
- Switch `LlamaFlexAttention` and `DeepSeekMLAFlexAttention` to use `eagle3_block_mask` instead of building masks via `create_block_mask` (a call-site sketch follows this list).
- Re-export the new APIs from `torchspec.models.ops` and add targeted unit tests for analytical-vs-reference parity.
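Illustrative shape of that call-site switch (a sketch only: the argument names of `eagle3_block_mask` are assumptions based on the description, not the repo's exact signature):

```python
from torch.nn.attention.flex_attention import flex_attention
from torchspec.models.ops import eagle3_block_mask  # re-export added by this PR

def eagle3_flex_attention(q, k, v):
    # q: [B, H, Q_LEN, D]; k, v: [B, H_kv, KV_LEN, D] with KV growing per TTT step.
    q_len, kv_len = q.shape[-2], k.shape[-2]
    # Dispatcher: analytical builder when q_len/kv_len are BLOCK_SIZE-aligned and
    # kv_len is a multiple of q_len, otherwise the create_block_mask fallback.
    block_mask = eagle3_block_mask(q_len, kv_len, device=q.device)
    return flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
```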
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| torchspec/models/ops/flex_attention.py | Adds the analytical Eagle3 BlockMask builder + dispatcher; modifies generate_eagle3_mask. |
| torchspec/models/ops/__init__.py | Re-exports build_eagle3_block_mask and eagle3_block_mask. |
| torchspec/models/draft/llama3_eagle.py | Adopts the dispatcher in LlamaFlexAttention. |
| torchspec/models/draft/deepseek_eagle.py | Adopts the dispatcher in DeepSeekMLAFlexAttention. |
| tests/test_build_eagle3_block_mask.py | Adds unit tests for mask structure, flex_attention parity, gradients, and dispatcher routing. |
Add `build_eagle3_block_mask`, an O(num_blocks)-memory direct constructor for the Eagle3 causal+suffix BlockMask, plus an `eagle3_block_mask` dispatcher that falls back to `create_block_mask` when shape preconditions don't hold (Q_LEN/KV_LEN BLOCK_SIZE alignment, KV_LEN multiple of Q_LEN).

Motivation: `create_block_mask` materialises the full (Q_LEN, KV_LEN) boolean grid internally, costing ~112 GB at Q=49K, KV=245K and OOM-ing on long-context Eagle3 training. The analytical builder writes the sparse `kv_indices` / `q_indices` tensors directly from the known mask structure, reducing peak memory to a few MB.

Also simplifies `generate_eagle3_mask` to drop the seq_lengths-aware padding mask (now only used by the fallback path; padding is enforced upstream via the attention mask).

Adopts the dispatcher in `LlamaFlexAttention` and `DeepSeekMLAFlexAttention`. Adoption in MoE draft models will land in a follow-up.

Tests: tests/test_build_eagle3_block_mask.py covers element-level equivalence vs the `create_block_mask` reference, GQA, forward/backward parity, dispatcher fallback paths, and a memory-bound assertion (<10 MB at Q=4096, KV=20480).

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
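To make the O(num_blocks) construction concrete, here is a minimal sketch of writing the block-sparse KV indices directly from the known causal+suffix structure. It mirrors the description above under the stated preconditions; names, layout, and the final BlockMask wrapping are assumptions, not the repo's implementation:

```python
import torch

def eagle3_kv_blocks(q_len, kv_len, block_size=128, device="cuda"):
    """Analytical (kv_num_blocks, kv_indices) for the Eagle3 causal+suffix mask.

    Preconditions from the PR: q_len % block_size == 0, kv_len % block_size == 0,
    kv_len % q_len == 0. Peak memory is O(num_blocks), never O(q_len * kv_len).
    """
    assert q_len % block_size == 0 and kv_len % block_size == 0 and kv_len % q_len == 0
    q_blocks = q_len // block_size
    kv_blocks = kv_len // block_size
    n_rounds = kv_len // q_len  # round 0 = causal region, rounds 1..n-1 = suffixes

    qi = torch.arange(q_blocks, device=device)
    # Row qi touches causal blocks 0..qi plus one diagonal block per suffix round,
    # located at column qi + r * q_blocks.
    kv_num_blocks = (qi + n_rounds).to(torch.int32).view(1, 1, q_blocks)

    suffix_cols = qi.view(-1, 1) + torch.arange(1, n_rounds, device=device).view(1, -1) * q_blocks
    kv_indices = torch.zeros(q_blocks, kv_blocks, dtype=torch.int32, device=device)
    for row in range(q_blocks):  # O(q_blocks) loop; cheap relative to the dense grid
        cols = torch.cat([torch.arange(row + 1, device=device), suffix_cols[row]])
        kv_indices[row, : cols.numel()] = cols.to(torch.int32)
    return kv_num_blocks, kv_indices.view(1, 1, q_blocks, kv_blocks)

# These tensors can then be wrapped into a BlockMask (e.g. via
# BlockMask.from_kv_blocks, with the Eagle3 mask_mod resolving the partial
# diagonal blocks), which is the role build_eagle3_block_mask plays in this PR.
```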
Force-pushed 8a8981f to 2b242a9.
Force-pushed 8776373 to 3424318.
Force-pushed 3424318 to 78eba5c.
…lockMask

Extends the analytical Eagle3 BlockMask builder (PR #91) to support Flash Attention 4 (Blackwell SM100+) and emit FULL block tensors so the attention kernel can fast-path strict-below-diagonal blocks.

Changes
-------
torchspec/models/ops/flex_attention.py
- _build_eagle3_block_mask_tensors now returns 8 tensors (kv_num, kv_idx, full_kv_num, full_kv_idx, q_num, q_idx, full_q_num, full_q_idx) matching the BlockMask layout consumed by both flex_attention's create_block_mask and FA4's BlockSparseTensorsTorch.
- Accepts rectangular Q_BS x KV_BS block sizes (FA4 uses 256x128 on Blackwell, 128x128 on Hopper); the BLOCK_SIZE arg accepts either an int (square) or a (Q_BS, KV_BS) tuple. Default stays 128 (square, backward-compatible with the existing FlexAttention path).
- Block classification:
  - Causal region (kv < Q_LEN): strictly below diagonal -> FULL (kernel skips mask_mod); diagonal slab -> PARTIAL (lower-triangular).
  - Suffix region (kv >= Q_LEN): per round, r PARTIAL diagonal blocks per Q-block; suffix blocks are never FULL.
- compile_friendly_create_block_mask threads BLOCK_SIZE through to create_block_mask so the fallback path can also run with rectangular blocks.
- eagle3_block_mask dispatcher accepts the same int|tuple BLOCK_SIZE and preserves the analytical-vs-fallback decision.

tests/test_build_eagle3_block_mask.py
- test_full_and_partial_blocks_match_create_block_mask: parity for both PARTIAL and FULL kv/q tensors against create_block_mask ground truth.
- test_non_square_block_size_forward: end-to-end forward with Q_BS=256, KV_BS=128 matches the reference flex_attention output.
- test_dispatcher_with_non_square_block_size: tuple BLOCK_SIZE routes through the analytical path.

Performance
-----------
- On Hopper (square 128): mask-tensor memory grows ~3x (KB/MB scale, negligible) because kv_indices is padded to width n_kv as required by flex_attention/FA4. The attention kernel is unchanged or faster: per-Q-row PARTIAL count drops from (qi + n_rounds) to a constant n_rounds, with the rest treated as FULL and skipping mask_mod.
- On Blackwell SM100+: this enables FA4 to run at all -- FA4's infer_block_sparse_expected_shapes asserts kv_indices.shape[-1] == n_kv, which the previous (compact) layout violated.

Correctness is verified against create_block_mask at dense, forward, backward, GQA-broadcast, and PARTIAL/FULL block-tensor levels; the torch.compile path is verified equal to eager and not to recompile across TTT-step KV growth.

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
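As a quick way to sanity-check the FULL vs PARTIAL classification described in this commit message, one can evaluate the mask one tile at a time (per-tile memory only). A small sketch with an illustrative Q_LEN and square 128 blocks; this is not the repo's test code:

```python
import torch

Q_LEN = 512  # illustrative; KV_LEN would be a multiple of this (e.g. 1024)

def eagle3_mask(q_idx, kv_idx):
    # Element-level Eagle3 mask: causal part plus one diagonal per suffix round.
    return (q_idx >= kv_idx) | ((kv_idx >= Q_LEN) & ((kv_idx - q_idx) % Q_LEN == 0))

def classify_block(qi, kj, q_bs=128, kv_bs=128):
    """Classify tile (qi, kj) as 'full', 'partial', or 'empty' by evaluating
    the mask only on that q_bs x kv_bs tile."""
    q = torch.arange(qi * q_bs, (qi + 1) * q_bs).view(-1, 1)
    kv = torch.arange(kj * kv_bs, (kj + 1) * kv_bs).view(1, -1)
    tile = eagle3_mask(q, kv)
    if tile.all():
        return "full"       # kernel can skip mask_mod for this block
    return "partial" if tile.any() else "empty"

print(classify_block(3, 1))  # strictly below the causal diagonal -> 'full'
print(classify_block(2, 2))  # diagonal slab -> 'partial' (lower-triangular)
print(classify_block(0, 4))  # suffix diagonal block -> 'partial'
print(classify_block(0, 1))  # above the causal diagonal -> 'empty'
```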
Summary
- `build_eagle3_block_mask`, an O(num_blocks)-memory direct constructor for the Eagle3 causal+suffix `BlockMask`. `create_block_mask` materialises the full `(Q_LEN, KV_LEN)` boolean grid internally (~112 GB at Q=49K, KV=245K), causing OOM on long-context Eagle3 training. The analytical builder writes the sparse `kv_indices`/`q_indices` tensors directly from the known mask structure, dropping peak memory to a few MB.
- `eagle3_block_mask` dispatcher: uses the analytical builder when the preconditions hold (`Q_LEN % BLOCK_SIZE == 0`, `KV_LEN % BLOCK_SIZE == 0`, `KV_LEN % Q_LEN == 0`), otherwise falls back to `create_block_mask` / `compile_friendly_create_block_mask` (covers tests/edge cases where the O(Q*KV) cost is irrelevant).
- Simplifies `generate_eagle3_mask` to drop the `seq_lengths`-aware padding mask. It's now only consumed by the dispatcher's fallback path; padding is enforced upstream via the attention mask, and the analytical path gets exact equivalence by construction (verified by the tests below).
- Re-exports `build_eagle3_block_mask` and `eagle3_block_mask` from `torchspec.models.ops`.

Testing
New file: `tests/test_build_eagle3_block_mask.py` (18 tests + 13 subtests), run locally on torch 2.9.1+cu129 with a single CUDA device.
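For illustration, a minimal forward-parity check in the spirit of these tests (a sketch: the argument names of `build_eagle3_block_mask` are assumptions, and the real test file covers far more cases):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torchspec.models.ops import build_eagle3_block_mask  # re-export added by this PR

def check_forward_parity(q_len=1024, kv_len=5120, block_size=128, device="cuda"):
    def eagle3_mask(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) | ((kv_idx >= q_len) & ((kv_idx - q_idx) % q_len == 0))

    # Reference mask via create_block_mask vs the analytical builder.
    ref_mask = create_block_mask(eagle3_mask, None, None, q_len, kv_len,
                                 device=device, BLOCK_SIZE=block_size)
    ana_mask = build_eagle3_block_mask(q_len, kv_len, device=device)  # assumed signature

    q = torch.randn(1, 8, q_len, 64, device=device, dtype=torch.float16)
    k = torch.randn(1, 8, kv_len, 64, device=device, dtype=torch.float16)
    v = torch.randn(1, 8, kv_len, 64, device=device, dtype=torch.float16)
    out_ref = flex_attention(q, k, v, block_mask=ref_mask)
    out_ana = flex_attention(q, k, v, block_mask=ana_mask)
    torch.testing.assert_close(out_ana, out_ref, rtol=1e-3, atol=1e-3)
```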
Smoke import check confirms `LlamaFlexAttention` and `DeepSeekMLAFlexAttention` still load with the dispatcher import. No new lints.

Solves #86