-
Notifications
You must be signed in to change notification settings - Fork 22
feat: analytical Eagle3 BlockMask builder (O(num_blocks) memory) #91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| """Tests for build_eagle3_block_mask -- the analytical Eagle3 BlockMask builder.""" | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| import torch._dynamo as dynamo | ||
| from torch.nn.attention.flex_attention import create_block_mask, flex_attention | ||
|
|
||
| from torchspec.models.ops.flex_attention import ( | ||
| _build_eagle3_block_mask_tensors, | ||
| build_eagle3_block_mask, | ||
| eagle3_block_mask, | ||
| generate_eagle3_mask, | ||
| ) | ||
|
|
||
| DEVICE = "cuda" | ||
| BLOCK_SIZE = 128 | ||
|
|
||
|
|
||
| def dense_from_mod(Q_LEN, KV_LEN, mask_mod, batch_idx=0): | ||
| """Materialise a (Q_LEN, KV_LEN) bool grid from a mask_mod or BlockMask.""" | ||
| qi = torch.arange(Q_LEN, device=DEVICE).unsqueeze(1) | ||
| ki = torch.arange(KV_LEN, device=DEVICE).unsqueeze(0) | ||
| b = torch.full_like(qi, batch_idx) | ||
| h = torch.zeros_like(qi) | ||
| fn = mask_mod.mask_mod if hasattr(mask_mod, "mask_mod") else mask_mod | ||
| return fn(b, h, qi, ki).bool() | ||
|
|
||
|
|
||
| def reference_block_mask(Q_LEN, KV_LEN, B=1, H=1): | ||
| """create_block_mask using the production simplified mask_mod.""" | ||
| return create_block_mask( | ||
| generate_eagle3_mask(Q_LEN, KV_LEN), | ||
| B=B, | ||
| H=H, | ||
| Q_LEN=Q_LEN, | ||
| KV_LEN=KV_LEN, | ||
| device=DEVICE, | ||
| ) | ||
|
|
||
|
|
||
| # Sizes covering single round, short-multi-round, and aligned-multi-round cases. | ||
| SHAPES = [(256, 256), (256, 768), (256, 1280), (1024, 4096)] | ||
|
|
||
|
|
||
| class TestBuildEagle3BlockMask(unittest.TestCase): | ||
| """Analytical builder must produce a mask equivalent to create_block_mask.""" | ||
|
|
||
| def test_dense_mask_matches_reference(self): | ||
| for Q, KV in SHAPES: | ||
| with self.subTest(Q=Q, KV=KV): | ||
| ref = dense_from_mod(Q, KV, reference_block_mask(Q, KV)) | ||
| ours = dense_from_mod(Q, KV, build_eagle3_block_mask(Q, KV, device=DEVICE)) | ||
| self.assertTrue(torch.equal(ref, ours)) | ||
|
|
||
| def test_forward_matches_reference(self): | ||
| torch.manual_seed(42) | ||
| B, H, D = 1, 4, 64 | ||
| for Q, KV in SHAPES: | ||
| with self.subTest(Q=Q, KV=KV): | ||
| q = torch.randn(B, H, Q, D, device=DEVICE, dtype=torch.bfloat16) | ||
| k = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16) | ||
| v = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16) | ||
| ref = flex_attention(q, k, v, block_mask=reference_block_mask(Q, KV)) | ||
| ours = flex_attention(q, k, v, block_mask=build_eagle3_block_mask(Q, KV, B=B)) | ||
| self.assertEqual(ref.shape, ours.shape) | ||
| self.assertFalse(ours.isnan().any()) | ||
| self.assertLess((ref - ours).abs().max().item(), 1e-5) | ||
|
|
||
| def test_backward_gradients_match_reference(self): | ||
| torch.manual_seed(42) | ||
| B, H, D, Q, KV = 1, 4, 64, 256, 768 | ||
|
|
||
| def grads(mask): | ||
| q = torch.randn(B, H, Q, D, device=DEVICE, dtype=torch.float32, requires_grad=True) | ||
| k = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.float32, requires_grad=True) | ||
| v = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.float32, requires_grad=True) | ||
| flex_attention(q, k, v, block_mask=mask).sum().backward() | ||
| return q.grad, k.grad, v.grad | ||
|
|
||
| torch.manual_seed(42) | ||
| gq_r, gk_r, gv_r = grads(reference_block_mask(Q, KV)) | ||
| torch.manual_seed(42) | ||
| gq_o, gk_o, gv_o = grads(build_eagle3_block_mask(Q, KV, B=B)) | ||
| for name, gr, go in [("q", gq_r, gq_o), ("k", gk_r, gk_o), ("v", gv_r, gv_o)]: | ||
| self.assertLess((gr - go).abs().max().item(), 1e-4, f"grad mismatch on {name}") | ||
|
|
||
| def test_gqa_broadcast(self): | ||
| """H=1 mask broadcasts over multi-Q-head GQA without NaN.""" | ||
| torch.manual_seed(0) | ||
| B, Qh, KVh, D, Q, KV = 1, 8, 2, 64, 256, 768 | ||
| q = torch.randn(B, Qh, Q, D, device=DEVICE, dtype=torch.bfloat16) | ||
| k = torch.randn(B, KVh, KV, D, device=DEVICE, dtype=torch.bfloat16) | ||
| v = torch.randn(B, KVh, KV, D, device=DEVICE, dtype=torch.bfloat16) | ||
| bm = build_eagle3_block_mask(Q, KV, B=B, device=DEVICE) | ||
| out = flex_attention(q, k, v, block_mask=bm, enable_gqa=True) | ||
| self.assertEqual(out.shape, (B, Qh, Q, D)) | ||
| self.assertFalse(out.isnan().any()) | ||
|
|
||
| def test_memory_is_negligible(self): | ||
| """Original create_block_mask costs ~112 GB at Q=49K; this must stay in MB.""" | ||
| Q, KV = 4096, 4096 * 5 | ||
| torch.cuda.reset_peak_memory_stats() | ||
| before = torch.cuda.memory_allocated() | ||
| build_eagle3_block_mask(Q, KV, device=DEVICE) | ||
| mem_mb = (torch.cuda.max_memory_allocated() - before) / 1024**2 | ||
| self.assertLess(mem_mb, 10.0, f"used {mem_mb:.1f} MB") | ||
|
|
||
| def test_assertions_on_invalid_shapes(self): | ||
| # not divisible by BLOCK_SIZE | ||
| with self.assertRaises(AssertionError): | ||
| build_eagle3_block_mask(100, 300, device=DEVICE) | ||
| # KV not a Q-multiple | ||
| with self.assertRaises(AssertionError): | ||
| build_eagle3_block_mask(256, 384, device=DEVICE) | ||
|
|
||
|
|
||
| class TestEagle3BlockMaskDispatcher(unittest.TestCase): | ||
| """Dispatcher picks analytical when shapes align, otherwise falls back.""" | ||
|
|
||
| def test_analytical_path_when_aligned(self): | ||
| for Q, KV in [(256, 256), (256, 768)]: | ||
| with self.subTest(Q=Q, KV=KV): | ||
| disp = eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE) | ||
| ana = build_eagle3_block_mask(Q, KV, device=DEVICE) | ||
| self.assertTrue(torch.equal(disp.kv_indices, ana.kv_indices)) | ||
| self.assertTrue(torch.equal(disp.q_indices, ana.q_indices)) | ||
|
|
||
| def test_fallback_path_matches_reference_mask_mod(self): | ||
| """Fallback shapes (Q<BLOCK_SIZE, or KV%Q!=0) must produce the canonical mask.""" | ||
| for Q, KV in [(64, 64), (256, 384)]: | ||
| with self.subTest(Q=Q, KV=KV): | ||
| bm = eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE) | ||
| expected = dense_from_mod(Q, KV, generate_eagle3_mask(Q, KV)) | ||
| self.assertTrue(torch.equal(dense_from_mod(Q, KV, bm), expected)) | ||
|
|
||
| def test_dispatcher_forward_matches_reference(self): | ||
| torch.manual_seed(0) | ||
| B, H, D = 1, 4, 64 | ||
| for Q, KV in [(256, 256), (256, 768)]: | ||
| with self.subTest(Q=Q, KV=KV): | ||
| q = torch.randn(B, H, Q, D, device=DEVICE, dtype=torch.bfloat16) | ||
| k = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16) | ||
| v = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16) | ||
| ref = flex_attention(q, k, v, block_mask=reference_block_mask(Q, KV, B=B)) | ||
| disp = flex_attention( | ||
| q, k, v, block_mask=eagle3_block_mask(Q, KV, B=B, H=1, device=DEVICE) | ||
| ) | ||
| self.assertLess((ref - disp).abs().max().item(), 1e-5) | ||
|
|
||
|
|
||
| class TestCompiledTensorBuilder(unittest.TestCase): | ||
| """build_eagle3_block_mask routes through torch.compile -- verify behaviour.""" | ||
|
|
||
| def test_compiled_output_matches_eager(self): | ||
| for Q, KV in [(256, 256), (256, 768), (1024, 4096)]: | ||
| with self.subTest(Q=Q, KV=KV): | ||
| eager = _build_eagle3_block_mask_tensors(Q, KV, 1, 1, BLOCK_SIZE, DEVICE) | ||
| bm = build_eagle3_block_mask(Q, KV, device=DEVICE) | ||
| self.assertTrue(torch.equal(bm.kv_num_blocks, eager[0])) | ||
| self.assertTrue(torch.equal(bm.kv_indices, eager[1])) | ||
| self.assertTrue(torch.equal(bm.q_num_blocks, eager[2])) | ||
| self.assertTrue(torch.equal(bm.q_indices, eager[3])) | ||
|
|
||
| def test_dynamic_true_does_not_recompile_across_growing_kv(self): | ||
| """KV_LEN grows by Q_LEN every TTT step; dynamic=True must keep one graph.""" | ||
| Q = 512 | ||
| # Warm up to lock the compiled artifact. | ||
| build_eagle3_block_mask(Q, Q, device=DEVICE) | ||
| dynamo.reset() | ||
| build_eagle3_block_mask(Q, Q, device=DEVICE) | ||
| before = dynamo.utils.counters["stats"].get("unique_graphs", 0) | ||
| for n_rounds in [2, 3, 4, 5]: | ||
| build_eagle3_block_mask(Q, Q * n_rounds, device=DEVICE) | ||
| after = dynamo.utils.counters["stats"].get("unique_graphs", 0) | ||
| # First call after dynamo.reset() compiles once (+1); growing KV must not add more. | ||
| self.assertLessEqual( | ||
| after - before, | ||
| 1, | ||
| f"dynamic=True triggered {after - before} extra graphs across growing KV", | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main(verbosity=2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.