185 changes: 185 additions & 0 deletions tests/test_build_eagle3_block_mask.py
@@ -0,0 +1,185 @@
"""Tests for build_eagle3_block_mask -- the analytical Eagle3 BlockMask builder."""

import unittest

import torch
import torch._dynamo as dynamo
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from torchspec.models.ops.flex_attention import (
_build_eagle3_block_mask_tensors,
build_eagle3_block_mask,
eagle3_block_mask,
generate_eagle3_mask,
)

DEVICE = "cuda"
BLOCK_SIZE = 128


def dense_from_mod(Q_LEN, KV_LEN, mask_mod, batch_idx=0):
"""Materialise a (Q_LEN, KV_LEN) bool grid from a mask_mod or BlockMask."""
qi = torch.arange(Q_LEN, device=DEVICE).unsqueeze(1)
ki = torch.arange(KV_LEN, device=DEVICE).unsqueeze(0)
b = torch.full_like(qi, batch_idx)
h = torch.zeros_like(qi)
fn = mask_mod.mask_mod if hasattr(mask_mod, "mask_mod") else mask_mod
return fn(b, h, qi, ki).bool()


def reference_block_mask(Q_LEN, KV_LEN, B=1, H=1):
"""create_block_mask using the production simplified mask_mod."""
return create_block_mask(
generate_eagle3_mask(Q_LEN, KV_LEN),
B=B,
H=H,
Q_LEN=Q_LEN,
KV_LEN=KV_LEN,
device=DEVICE,
)


# Sizes covering single round, short-multi-round, and aligned-multi-round cases.
SHAPES = [(256, 256), (256, 768), (256, 1280), (1024, 4096)]


class TestBuildEagle3BlockMask(unittest.TestCase):
"""Analytical builder must produce a mask equivalent to create_block_mask."""

def test_dense_mask_matches_reference(self):
for Q, KV in SHAPES:
with self.subTest(Q=Q, KV=KV):
ref = dense_from_mod(Q, KV, reference_block_mask(Q, KV))
ours = dense_from_mod(Q, KV, build_eagle3_block_mask(Q, KV, device=DEVICE))
self.assertTrue(torch.equal(ref, ours))

def test_forward_matches_reference(self):
torch.manual_seed(42)
B, H, D = 1, 4, 64
for Q, KV in SHAPES:
with self.subTest(Q=Q, KV=KV):
q = torch.randn(B, H, Q, D, device=DEVICE, dtype=torch.bfloat16)
k = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16)
v = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16)
ref = flex_attention(q, k, v, block_mask=reference_block_mask(Q, KV))
ours = flex_attention(q, k, v, block_mask=build_eagle3_block_mask(Q, KV, B=B))
self.assertEqual(ref.shape, ours.shape)
self.assertFalse(ours.isnan().any())
self.assertLess((ref - ours).abs().max().item(), 1e-5)

def test_backward_gradients_match_reference(self):
torch.manual_seed(42)
B, H, D, Q, KV = 1, 4, 64, 256, 768

def grads(mask):
q = torch.randn(B, H, Q, D, device=DEVICE, dtype=torch.float32, requires_grad=True)
k = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.float32, requires_grad=True)
v = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.float32, requires_grad=True)
flex_attention(q, k, v, block_mask=mask).sum().backward()
return q.grad, k.grad, v.grad

torch.manual_seed(42)
gq_r, gk_r, gv_r = grads(reference_block_mask(Q, KV))
torch.manual_seed(42)
gq_o, gk_o, gv_o = grads(build_eagle3_block_mask(Q, KV, B=B))
for name, gr, go in [("q", gq_r, gq_o), ("k", gk_r, gk_o), ("v", gv_r, gv_o)]:
self.assertLess((gr - go).abs().max().item(), 1e-4, f"grad mismatch on {name}")

def test_gqa_broadcast(self):
"""H=1 mask broadcasts over multi-Q-head GQA without NaN."""
torch.manual_seed(0)
B, Qh, KVh, D, Q, KV = 1, 8, 2, 64, 256, 768
q = torch.randn(B, Qh, Q, D, device=DEVICE, dtype=torch.bfloat16)
k = torch.randn(B, KVh, KV, D, device=DEVICE, dtype=torch.bfloat16)
v = torch.randn(B, KVh, KV, D, device=DEVICE, dtype=torch.bfloat16)
bm = build_eagle3_block_mask(Q, KV, B=B, device=DEVICE)
out = flex_attention(q, k, v, block_mask=bm, enable_gqa=True)
self.assertEqual(out.shape, (B, Qh, Q, D))
self.assertFalse(out.isnan().any())

def test_memory_is_negligible(self):
"""Original create_block_mask costs ~112 GB at Q=49K; this must stay in MB."""
Q, KV = 4096, 4096 * 5
torch.cuda.reset_peak_memory_stats()
before = torch.cuda.memory_allocated()
build_eagle3_block_mask(Q, KV, device=DEVICE)
mem_mb = (torch.cuda.max_memory_allocated() - before) / 1024**2
self.assertLess(mem_mb, 10.0, f"used {mem_mb:.1f} MB")

def test_assertions_on_invalid_shapes(self):
# not divisible by BLOCK_SIZE
with self.assertRaises(AssertionError):
build_eagle3_block_mask(100, 300, device=DEVICE)
# KV not a Q-multiple
with self.assertRaises(AssertionError):
build_eagle3_block_mask(256, 384, device=DEVICE)


class TestEagle3BlockMaskDispatcher(unittest.TestCase):
"""Dispatcher picks analytical when shapes align, otherwise falls back."""

def test_analytical_path_when_aligned(self):
for Q, KV in [(256, 256), (256, 768)]:
with self.subTest(Q=Q, KV=KV):
disp = eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE)
ana = build_eagle3_block_mask(Q, KV, device=DEVICE)
self.assertTrue(torch.equal(disp.kv_indices, ana.kv_indices))
self.assertTrue(torch.equal(disp.q_indices, ana.q_indices))

def test_fallback_path_matches_reference_mask_mod(self):
"""Fallback shapes (Q<BLOCK_SIZE, or KV%Q!=0) must produce the canonical mask."""
for Q, KV in [(64, 64), (256, 384)]:
with self.subTest(Q=Q, KV=KV):
bm = eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE)
expected = dense_from_mod(Q, KV, generate_eagle3_mask(Q, KV))
self.assertTrue(torch.equal(dense_from_mod(Q, KV, bm), expected))

def test_dispatcher_forward_matches_reference(self):
torch.manual_seed(0)
B, H, D = 1, 4, 64
for Q, KV in [(256, 256), (256, 768)]:
with self.subTest(Q=Q, KV=KV):
q = torch.randn(B, H, Q, D, device=DEVICE, dtype=torch.bfloat16)
k = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16)
v = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16)
ref = flex_attention(q, k, v, block_mask=reference_block_mask(Q, KV, B=B))
disp = flex_attention(
q, k, v, block_mask=eagle3_block_mask(Q, KV, B=B, H=1, device=DEVICE)
)
self.assertLess((ref - disp).abs().max().item(), 1e-5)


class TestCompiledTensorBuilder(unittest.TestCase):
"""build_eagle3_block_mask routes through torch.compile -- verify behaviour."""

def test_compiled_output_matches_eager(self):
for Q, KV in [(256, 256), (256, 768), (1024, 4096)]:
with self.subTest(Q=Q, KV=KV):
eager = _build_eagle3_block_mask_tensors(Q, KV, 1, 1, BLOCK_SIZE, DEVICE)
bm = build_eagle3_block_mask(Q, KV, device=DEVICE)
self.assertTrue(torch.equal(bm.kv_num_blocks, eager[0]))
self.assertTrue(torch.equal(bm.kv_indices, eager[1]))
self.assertTrue(torch.equal(bm.q_num_blocks, eager[2]))
self.assertTrue(torch.equal(bm.q_indices, eager[3]))

def test_dynamic_true_does_not_recompile_across_growing_kv(self):
"""KV_LEN grows by Q_LEN every TTT step; dynamic=True must keep one graph."""
Q = 512
# Warm up to lock the compiled artifact.
build_eagle3_block_mask(Q, Q, device=DEVICE)
dynamo.reset()
build_eagle3_block_mask(Q, Q, device=DEVICE)
before = dynamo.utils.counters["stats"].get("unique_graphs", 0)
for n_rounds in [2, 3, 4, 5]:
build_eagle3_block_mask(Q, Q * n_rounds, device=DEVICE)
after = dynamo.utils.counters["stats"].get("unique_graphs", 0)
        # Budget of one extra graph covers a possible dynamic-shape recompile on the
        # first growing-KV call; beyond that, growing KV must reuse the same graph.
self.assertLessEqual(
after - before,
1,
f"dynamic=True triggered {after - before} extra graphs across growing KV",
)


if __name__ == "__main__":
unittest.main(verbosity=2)
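
For orientation, here is a minimal sketch of the analytical construction these tests exercise, assuming the simplified Eagle3 pattern (causal over the first Q_LEN columns plus one shifted diagonal per later round) and aligned shapes. The name sketch_eagle3_block_mask and its tensor layout are hypothetical stand-ins, not the repo's _build_eagle3_block_mask_tensors; BlockMask.from_kv_blocks is the public PyTorch constructor that skips materialising the dense Q x KV grid, which is where the memory savings come from.

import torch
from torch.nn.attention.flex_attention import BlockMask

def sketch_eagle3_block_mask(Q_LEN, KV_LEN, BLOCK_SIZE=128, device="cuda"):
    # Illustrative only: enumerate the visible kv blocks per q block, then let
    # mask_mod trim the partial (causal/diagonal) blocks at kernel time.
    assert Q_LEN % BLOCK_SIZE == 0 and KV_LEN % Q_LEN == 0
    QB, KVB = Q_LEN // BLOCK_SIZE, KV_LEN // BLOCK_SIZE
    n_rounds = KV_LEN // Q_LEN
    kv_num_blocks = torch.zeros(1, 1, QB, dtype=torch.int32, device=device)
    kv_indices = torch.zeros(1, 1, QB, KVB, dtype=torch.int32, device=device)
    for qb in range(QB):
        # Causal blocks of round 0, then the single diagonal block of each later round.
        cols = list(range(qb + 1)) + [qb + r * QB for r in range(1, n_rounds)]
        kv_num_blocks[0, 0, qb] = len(cols)
        kv_indices[0, 0, qb, : len(cols)] = torch.tensor(cols, dtype=torch.int32, device=device)

    def mask_mod(b, h, q_idx, kv_idx):
        # Causal within round 0; exact shifted diagonal within rounds 1..n-1.
        return (kv_idx <= q_idx) | ((kv_idx - q_idx) % Q_LEN == 0)

    return BlockMask.from_kv_blocks(
        kv_num_blocks, kv_indices, BLOCK_SIZE=BLOCK_SIZE, mask_mod=mask_mod
    )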
12 changes: 5 additions & 7 deletions tests/test_flex_attention.py
@@ -252,18 +252,16 @@ def test_eagle3_flex_mask(self):
         query = norm_tensor((B, H, S, D), device="cuda", dtype=data_type)
         key_cache = norm_tensor((B, H, KV_LEN, D), device="cuda", dtype=data_type)
         value_cache = norm_tensor((B, H, KV_LEN, D), device="cuda", dtype=data_type)
-        seq_lengths = torch.tensor([S], device="cuda", dtype=torch.int32)
-        seq_lengths -= lck
         block_mask = compile_friendly_create_block_mask(
-            mask_mod=generate_eagle3_mask(
-                seq_lengths=seq_lengths, Q_LEN=Q_LEN, KV_LEN=KV_LEN, lck=lck
-            ),
+            mask_mod=generate_eagle3_mask(Q_LEN=Q_LEN, KV_LEN=KV_LEN, lck=lck),
             B=1,
             H=1,
             Q_LEN=Q_LEN,
             KV_LEN=KV_LEN,
             device=query.device,
         )
+        # PR #91 simplified generate_eagle3_mask to drop seq_lengths-aware shrinking;
+        # the mask is now the full causal+suffix pattern at every q row.
         # fmt: off
         expected_mask = torch.tensor([[[
             [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
@@ -272,8 +270,8 @@ def test_eagle3_flex_mask(self):
             [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
             [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
             [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
-            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+            [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
         ]]], dtype=torch.int32).to(query.device)
         # fmt: on
         dense_mask = block_mask.to_dense()
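
As a sanity check, the new expected_mask above is reproducible from a one-line predicate; this predicate is inferred from the printed grid itself, not copied from generate_eagle3_mask, and it elides whatever lck does to the pattern.

import torch

Q_LEN, KV_LEN = 8, 24  # three rounds of 8 columns, matching the 8 x 24 grid above
q = torch.arange(Q_LEN).unsqueeze(1)
k = torch.arange(KV_LEN).unsqueeze(0)
# Causal over round 0, plus a shifted diagonal in each later round.
mask = ((k <= q) | ((k - q) % Q_LEN == 0)).int()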
32 changes: 9 additions & 23 deletions torchspec/models/draft/deepseek_eagle.py
@@ -27,7 +27,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from torch.nn.attention.flex_attention import flex_attention
 from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
 
 from torchspec.models.draft.base import Eagle3DraftModel
@@ -50,9 +50,8 @@
     yarn_get_mscale,
 )
 from torchspec.models.ops.flex_attention import (
-    compile_friendly_create_block_mask,
     compile_friendly_flex_attention,
-    generate_eagle3_mask,
+    eagle3_block_mask,
 )
 from torchspec.utils.logging import logger, print_with_rank
 
@@ -387,7 +386,8 @@ class DeepSeekMLAFlexAttention(DeepSeekMLAAttention):
     cache_keys: [B, H, total_seq, qk_head_dim]
     cache_values: [B, H, total_seq, v_head_dim]
 
-    EAGLE3 mask pattern is handled by generate_eagle3_mask + create_block_mask.
+    EAGLE3 mask pattern is handled by eagle3_block_mask (analytical when shapes
+    align to BLOCK_SIZE/Q_LEN, create_block_mask fallback otherwise).
     """
 
     def forward(
@@ -419,29 +419,15 @@ def forward(
             key_cache = key_states
             value_cache = value_states
 
-        # Build EAGLE3 block mask from attention_mask (seq_lengths)
-        seq_lengths = attention_mask.sum(dim=-1)
-        seq_lengths -= lck
+        flex_attention_func = flex_attention if q_len <= 128 else compile_friendly_flex_attention
 
-        if q_len <= 128:
-            create_block_mask_func = create_block_mask
-            flex_attention_func = flex_attention
-        else:
-            create_block_mask_func = compile_friendly_create_block_mask
-            flex_attention_func = compile_friendly_flex_attention
-
-        block_mask = create_block_mask_func(
-            mask_mod=generate_eagle3_mask(
-                seq_lengths=seq_lengths,
-                Q_LEN=q_len,
-                KV_LEN=key_cache.shape[-2],
-                lck=lck,
-            ),
-            B=bsz,
-            H=1,  # Rely on broadcast
+        block_mask = eagle3_block_mask(
+            Q_LEN=q_len,
+            KV_LEN=key_cache.shape[-2],
+            B=bsz,
+            H=1,  # Rely on broadcast
             device=query_states.device,
+            lck=lck,
         )
 
         attn_output = flex_attention_func(
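
For context, the dispatch these new call sites rely on plausibly reduces to the following sketch. The alignment test is inferred from the aligned/fallback shape lists in the tests (analytical needs Q_LEN % BLOCK_SIZE == 0 and KV_LEN % Q_LEN == 0), and the real eagle3_block_mask signature may differ.

from torch.nn.attention.flex_attention import create_block_mask

from torchspec.models.ops.flex_attention import generate_eagle3_mask

def eagle3_block_mask_sketch(Q_LEN, KV_LEN, B=1, H=1, device="cuda", lck=0, BLOCK_SIZE=128):
    # Analytical tensors only when the block math is exact; otherwise fall
    # back to the canonical mask_mod routed through create_block_mask.
    if Q_LEN % BLOCK_SIZE == 0 and KV_LEN % Q_LEN == 0:
        return sketch_eagle3_block_mask(Q_LEN, KV_LEN, BLOCK_SIZE, device)  # see sketch above
    return create_block_mask(
        generate_eagle3_mask(Q_LEN=Q_LEN, KV_LEN=KV_LEN, lck=lck),
        B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device,
    )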
32 changes: 7 additions & 25 deletions torchspec/models/draft/llama3_eagle.py
@@ -30,8 +30,8 @@
 
 from torchspec.models.draft.base import Eagle3DraftModel
 from torchspec.models.ops.flex_attention import (
-    compile_friendly_create_block_mask,
     compile_friendly_flex_attention,
+    eagle3_block_mask,
     generate_eagle3_mask,
 )
 from torchspec.utils.logging import logger, print_with_rank
@@ -937,9 +937,7 @@ def _build_eagle3_mask_pair(
         raise ValueError(f"Unknown eagle3 mask mode {mode!r}")
 
     # Always build a flex-compatible mask_mod for block-sparse iteration.
-    seq_lengths = torch.full((bsz,), q_len, dtype=torch.long, device=device)
     mask_mod_flex = generate_eagle3_mask(
-        seq_lengths=seq_lengths,
         Q_LEN=q_len,
         KV_LEN=kv_len,
         lck=lck,
@@ -1406,31 +1404,15 @@ def forward(
             key_cache = key_states
             value_cache = value_states
 
-        seq_lengths = attention_mask.sum(dim=-1)
-        # Shrink the attention mask to align with the padding to the right.
-        # This is equivalent to the shrinking logic in eagle3.py
-        seq_lengths -= lck
-        # TODO: Remove the usage of uncompiled create_block_mask after
-        # https://github.com/pytorch/pytorch/issues/160018
-        if q_len <= 128:
-            create_block_mask_func = create_block_mask
-            flex_attention_func = flex_attention
-        else:
-            create_block_mask_func = compile_friendly_create_block_mask
-            flex_attention_func = compile_friendly_flex_attention
-
-        block_mask = create_block_mask_func(
-            mask_mod=generate_eagle3_mask(
-                seq_lengths=seq_lengths,
-                Q_LEN=q_len,
-                KV_LEN=key_cache.shape[-2],
-                lck=lck,
-            ),
-            B=bsz,
-            H=1,  # Rely on broadcast
+        flex_attention_func = flex_attention if q_len <= 128 else compile_friendly_flex_attention
+
+        block_mask = eagle3_block_mask(
+            Q_LEN=q_len,
+            KV_LEN=key_cache.shape[-2],
+            B=bsz,
+            H=1,  # Rely on broadcast
             device=query_states.device,
+            lck=lck,
         )
         attn_output = flex_attention_func(
             query=query_states,
4 changes: 4 additions & 0 deletions torchspec/models/ops/__init__.py
@@ -19,16 +19,20 @@
 # SOFTWARE.
 
 from torchspec.models.ops.flex_attention import (
+    build_eagle3_block_mask,
     compile_friendly_create_block_mask,
     compile_friendly_flex_attention,
+    eagle3_block_mask,
     generate_eagle3_mask,
 )
 from torchspec.models.ops.loss import compiled_forward_kl_loss
 from torchspec.models.ops.loss_mask import compute_assistant_loss_mask
 
 __all__ = [
+    "build_eagle3_block_mask",
     "compile_friendly_create_block_mask",
     "compile_friendly_flex_attention",
+    "eagle3_block_mask",
     "generate_eagle3_mask",
     "compiled_forward_kl_loss",
     "compute_assistant_loss_mask",