Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 117 additions & 5 deletions tests/test_build_eagle3_block_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,91 @@ def test_assertions_on_invalid_shapes(self):
with self.assertRaises(AssertionError):
build_eagle3_block_mask(256, 384, device=DEVICE)

def _visited_blocks(self, num, idx):
"""Return set of (row, col) block pairs from a (B,H,rows,W) num/idx pair."""
rows = num.shape[-1]
seen = set()
for r in range(rows):
for c in range(num[0, 0, r].item()):
seen.add((r, idx[0, 0, r, c].item()))
return seen

def test_full_and_partial_blocks_match_create_block_mask(self):
    """Analytical builder must emit the same {full,partial} block split as
    create_block_mask -- this is what FA4 BlockSparseTensorsTorch consumes."""
    cases = [
        (256, 256, 128),
        (256, 768, 128),
        (1024, 4096, 128),
        # FA4 Blackwell SM100+ uses Q_BS=256, KV_BS=128.
        (512, 1024, (256, 128)),
        (1024, 4096, (256, 128)),
    ]
    for Q, KV, BS in cases:
        with self.subTest(Q=Q, KV=KV, BS=BS):
            ref = create_block_mask(
                generate_eagle3_mask(Q, KV),
                B=1,
                H=1,
                Q_LEN=Q,
                KV_LEN=KV,
                device=DEVICE,
                BLOCK_SIZE=BS,
            )
            ours = build_eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE, BLOCK_SIZE=BS)
            # Check both iteration directions: KV-direction (forward) and
            # Q-direction (backward), for partial and full block tables.
            tensor_pairs = (
                ("kv_num_blocks", "kv_indices"),
                ("full_kv_num_blocks", "full_kv_indices"),
                ("q_num_blocks", "q_indices"),
                ("full_q_num_blocks", "full_q_indices"),
            )
            for num_name, idx_name in tensor_pairs:
                self.assertEqual(
                    self._visited_blocks(getattr(ref, num_name), getattr(ref, idx_name)),
                    self._visited_blocks(getattr(ours, num_name), getattr(ours, idx_name)),
                )

def test_non_square_block_size_forward(self):
    """Q_BS != KV_BS (Blackwell case): output must match reference."""
    torch.manual_seed(123)
    B, H, D = 1, 4, 64
    for Q, KV, BS in [(512, 1024, (256, 128)), (1024, 4096, (256, 128))]:
        with self.subTest(Q=Q, KV=KV, BS=BS):
            # q has sequence length Q; k and v have sequence length KV.
            # The generator draws in the same order as three separate
            # randn calls, so the RNG stream is unchanged.
            q, k, v = (
                torch.randn(B, H, length, D, device=DEVICE, dtype=torch.bfloat16)
                for length in (Q, KV, KV)
            )
            ref_mask = create_block_mask(
                generate_eagle3_mask(Q, KV),
                B=1,
                H=1,
                Q_LEN=Q,
                KV_LEN=KV,
                device=DEVICE,
                BLOCK_SIZE=BS,
            )
            ref = flex_attention(q, k, v, block_mask=ref_mask)
            ours_mask = build_eagle3_block_mask(
                Q, KV, B=B, H=1, device=DEVICE, BLOCK_SIZE=BS
            )
            ours = flex_attention(q, k, v, block_mask=ours_mask)
            self.assertEqual(ref.shape, ours.shape)
            self.assertFalse(ours.isnan().any())
            self.assertLess((ref - ours).abs().max().item(), 1e-5)


class TestEagle3BlockMaskDispatcher(unittest.TestCase):
"""Dispatcher picks analytical when shapes align, otherwise falls back."""
Expand All @@ -125,6 +210,17 @@ def test_analytical_path_when_aligned(self):
ana = build_eagle3_block_mask(Q, KV, device=DEVICE)
self.assertTrue(torch.equal(disp.kv_indices, ana.kv_indices))
self.assertTrue(torch.equal(disp.q_indices, ana.q_indices))
self.assertTrue(torch.equal(disp.full_kv_indices, ana.full_kv_indices))
self.assertTrue(torch.equal(disp.full_q_indices, ana.full_q_indices))

def test_dispatcher_with_non_square_block_size(self):
    """Tuple BLOCK_SIZE (e.g. FA4 Blackwell 256x128) routes through analytical path."""
    Q, KV, BS = 512, 2048, (256, 128)
    disp = eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE, BLOCK_SIZE=BS)
    ana = build_eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE, BLOCK_SIZE=BS)
    # Dispatcher must keep the tuple block size and emit the same index tables
    # as calling the analytical builder directly.
    self.assertEqual(disp.BLOCK_SIZE, (256, 128))
    for attr in ("kv_indices", "full_kv_indices"):
        self.assertTrue(torch.equal(getattr(disp, attr), getattr(ana, attr)))

def test_fallback_path_matches_reference_mask_mod(self):
"""Fallback shapes (Q<BLOCK_SIZE, or KV%Q!=0) must produce the canonical mask."""
Expand Down Expand Up @@ -155,12 +251,28 @@ class TestCompiledTensorBuilder(unittest.TestCase):
def test_compiled_output_matches_eager(self):
    """Compiled block-mask builder must reproduce the eager tensor builder.

    NOTE(review): this span in SOURCE is a rendered diff that interleaves the
    removed lines with their replacements (two `eager = ...` assignments, two
    assertion sets), which is not valid Python. This body keeps only the
    post-change version: the eager builder now takes separate Q and KV block
    sizes, and the full_* tables are compared as well.
    """
    for Q, KV in [(256, 256), (256, 768), (1024, 4096)]:
        with self.subTest(Q=Q, KV=KV):
            # Eager reference: Q_BS and KV_BS are passed separately (both
            # BLOCK_SIZE here, the square-block case).
            (
                e_kv_num,
                e_kv_idx,
                e_full_kv_num,
                e_full_kv_idx,
                e_q_num,
                e_q_idx,
                e_full_q_num,
                e_full_q_idx,
            ) = _build_eagle3_block_mask_tensors(
                Q, KV, 1, 1, BLOCK_SIZE, BLOCK_SIZE, DEVICE
            )
            bm = build_eagle3_block_mask(Q, KV, device=DEVICE)
            self.assertTrue(torch.equal(bm.kv_num_blocks, e_kv_num))
            self.assertTrue(torch.equal(bm.kv_indices, e_kv_idx))
            self.assertTrue(torch.equal(bm.full_kv_num_blocks, e_full_kv_num))
            self.assertTrue(torch.equal(bm.full_kv_indices, e_full_kv_idx))
            self.assertTrue(torch.equal(bm.q_num_blocks, e_q_num))
            self.assertTrue(torch.equal(bm.q_indices, e_q_idx))
            self.assertTrue(torch.equal(bm.full_q_num_blocks, e_full_q_num))
            self.assertTrue(torch.equal(bm.full_q_indices, e_full_q_idx))

def test_dynamic_true_does_not_recompile_across_growing_kv(self):
"""KV_LEN grows by Q_LEN every TTT step; dynamic=True must keep one graph."""
Expand Down
Loading
Loading