diff --git a/tests/test_build_eagle3_block_mask.py b/tests/test_build_eagle3_block_mask.py
index 138d902c..4d0ec1e8 100644
--- a/tests/test_build_eagle3_block_mask.py
+++ b/tests/test_build_eagle3_block_mask.py
@@ -114,6 +114,91 @@ def test_assertions_on_invalid_shapes(self):
         with self.assertRaises(AssertionError):
             build_eagle3_block_mask(256, 384, device=DEVICE)
 
+    def _visited_blocks(self, num, idx):
+        """Return set of (row, col) block pairs from a (B,H,rows,W) num/idx pair."""
+        rows = num.shape[-1]
+        seen = set()
+        for r in range(rows):
+            for c in range(num[0, 0, r].item()):
+                seen.add((r, idx[0, 0, r, c].item()))
+        return seen
+
+    def test_full_and_partial_blocks_match_create_block_mask(self):
+        """Analytical builder must emit the same {full,partial} block split as
+        create_block_mask -- this is what FA4 BlockSparseTensorsTorch consumes."""
+        for Q, KV, BS in [
+            (256, 256, 128),
+            (256, 768, 128),
+            (1024, 4096, 128),
+            # FA4 Blackwell SM100+ uses Q_BS=256, KV_BS=128.
+            (512, 1024, (256, 128)),
+            (1024, 4096, (256, 128)),
+        ]:
+            with self.subTest(Q=Q, KV=KV, BS=BS):
+                ref = create_block_mask(
+                    generate_eagle3_mask(Q, KV),
+                    B=1,
+                    H=1,
+                    Q_LEN=Q,
+                    KV_LEN=KV,
+                    device=DEVICE,
+                    BLOCK_SIZE=BS,
+                )
+                ours = build_eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE, BLOCK_SIZE=BS)
+                # KV-direction (forward iteration).
+                self.assertEqual(
+                    self._visited_blocks(ref.kv_num_blocks, ref.kv_indices),
+                    self._visited_blocks(ours.kv_num_blocks, ours.kv_indices),
+                )
+                self.assertEqual(
+                    self._visited_blocks(ref.full_kv_num_blocks, ref.full_kv_indices),
+                    self._visited_blocks(ours.full_kv_num_blocks, ours.full_kv_indices),
+                )
+                # Q-direction (backward iteration).
+                self.assertEqual(
+                    self._visited_blocks(ref.q_num_blocks, ref.q_indices),
+                    self._visited_blocks(ours.q_num_blocks, ours.q_indices),
+                )
+                self.assertEqual(
+                    self._visited_blocks(ref.full_q_num_blocks, ref.full_q_indices),
+                    self._visited_blocks(ours.full_q_num_blocks, ours.full_q_indices),
+                )
+
+    def test_non_square_block_size_forward(self):
+        """Q_BS != KV_BS (Blackwell case): output must match reference."""
+        torch.manual_seed(123)
+        B, H, D = 1, 4, 64
+        for Q, KV, BS in [(512, 1024, (256, 128)), (1024, 4096, (256, 128))]:
+            with self.subTest(Q=Q, KV=KV, BS=BS):
+                q = torch.randn(B, H, Q, D, device=DEVICE, dtype=torch.bfloat16)
+                k = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16)
+                v = torch.randn(B, H, KV, D, device=DEVICE, dtype=torch.bfloat16)
+                ref = flex_attention(
+                    q,
+                    k,
+                    v,
+                    block_mask=create_block_mask(
+                        generate_eagle3_mask(Q, KV),
+                        B=1,
+                        H=1,
+                        Q_LEN=Q,
+                        KV_LEN=KV,
+                        device=DEVICE,
+                        BLOCK_SIZE=BS,
+                    ),
+                )
+                ours = flex_attention(
+                    q,
+                    k,
+                    v,
+                    block_mask=build_eagle3_block_mask(
+                        Q, KV, B=B, H=1, device=DEVICE, BLOCK_SIZE=BS
+                    ),
+                )
+                self.assertEqual(ref.shape, ours.shape)
+                self.assertFalse(ours.isnan().any())
+                self.assertLess((ref - ours).abs().max().item(), 1e-5)
+
 
 class TestEagle3BlockMaskDispatcher(unittest.TestCase):
     """Dispatcher picks analytical when shapes align, otherwise falls back."""
@@ -125,6 +210,17 @@ def test_analytical_path_when_aligned(self):
         ana = build_eagle3_block_mask(Q, KV, device=DEVICE)
         self.assertTrue(torch.equal(disp.kv_indices, ana.kv_indices))
         self.assertTrue(torch.equal(disp.q_indices, ana.q_indices))
+        self.assertTrue(torch.equal(disp.full_kv_indices, ana.full_kv_indices))
+        self.assertTrue(torch.equal(disp.full_q_indices, ana.full_q_indices))
+
+    def test_dispatcher_with_non_square_block_size(self):
+        """Tuple BLOCK_SIZE (e.g. FA4 Blackwell 256x128) routes through analytical path."""
+        Q, KV, BS = 512, 2048, (256, 128)
+        disp = eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE, BLOCK_SIZE=BS)
+        ana = build_eagle3_block_mask(Q, KV, B=1, H=1, device=DEVICE, BLOCK_SIZE=BS)
+        self.assertEqual(disp.BLOCK_SIZE, (256, 128))
+        self.assertTrue(torch.equal(disp.kv_indices, ana.kv_indices))
+        self.assertTrue(torch.equal(disp.full_kv_indices, ana.full_kv_indices))
 
     def test_fallback_path_matches_reference_mask_mod(self):
         """Fallback shapes (Q

                                                  FULL (all True, skip mask_mod)
+        * Diagonal slab (kj in [qi*r, (qi+1)*r)) -> PARTIAL (lower-triangular)
+        * Above diagonal -> empty (omitted)
+      Suffix region (kv >= Q_LEN), one round per Q_LEN cols:
+        * Per round, Q-block qi has r PARTIAL blocks (one per BK-slot within Q-block).
+        * Suffix blocks are never FULL (mask is at most a thin diagonal).
     """
-    n_q = Q_LEN // BLOCK_SIZE
-    n_kv = KV_LEN // BLOCK_SIZE
+    r = Q_BS // KV_BS  # KV-blocks per Q-block (1 on Hopper, 2 on Blackwell SM100+)
+    n_q = Q_LEN // Q_BS
+    n_kv = KV_LEN // KV_BS
+    n_kv_causal = Q_LEN // KV_BS  # KV-blocks covering the first (causal) round
     n_rounds = KV_LEN // Q_LEN
-    # Row qi attends to causal cols 0..qi and one diagonal col per suffix round
-    # at (col-qi)*n_q + qi; total qi + n_rounds entries per row.
-    max_kv_per_row = n_q + n_rounds - 1
+    # ---- KV direction (forward iteration: per Q-block, list KV-blocks) -----
+    # Both partial and full kv_indices must be width n_kv -- this is what FA4
+    # BlockSparseTensorsTorch + flex_attention's create_block_mask emit; FA4's
+    # infer_block_sparse_expected_shapes asserts on it explicitly. Only the
+    # first kv_num/full_kv_num entries per row are read; the rest are padding.
     qi = torch.arange(n_q, device=device, dtype=torch.int32)
-    col = torch.arange(max_kv_per_row, device=device, dtype=torch.int32)
     qi_b = qi.unsqueeze(1)
-    col_b = col.unsqueeze(0)
-    is_causal = col_b <= qi_b
-    causal_kv = col_b.expand(n_q, max_kv_per_row)
-    suffix_kv = (col_b - qi_b) * n_q + qi_b
-    kv_idx_2d = torch.where(is_causal, causal_kv, suffix_kv)
-    valid = col_b < (qi_b + n_rounds)
-    kv_idx_2d = torch.where(valid, kv_idx_2d, torch.zeros_like(kv_idx_2d))
-    kv_num_1d = (qi + n_rounds).to(torch.int32)
-
-    # Column ki: r=ki//n_q, pos=ki%n_q. r==0 -> q in [pos, n_q); r>=1 -> q==pos.
-    max_q_per_col = n_q
-    ki = torch.arange(n_kv, device=device, dtype=torch.int32)
-    col_q = torch.arange(max_q_per_col, device=device, dtype=torch.int32)
-    r = ki // n_q
-    pos = ki % n_q
-    r_b = r.unsqueeze(1)
-    pos_b = pos.unsqueeze(1)
-    col_q_b = col_q.unsqueeze(0)
-    q_idx_2d = torch.where(r_b == 0, pos_b + col_q_b, pos_b.expand(n_kv, max_q_per_col))
-    q_num_1d = torch.where(r == 0, n_q - pos, torch.ones_like(pos)).to(torch.int32)
-    valid_q = col_q_b < q_num_1d.unsqueeze(1)
-    q_idx_2d = torch.where(valid_q, q_idx_2d, torch.zeros_like(q_idx_2d))
+    col_kv = torch.arange(n_kv, device=device, dtype=torch.int32).unsqueeze(0)
+
+    # FULL: kj in [0, qi*r) (causal-below-diagonal).
+    full_kv_num_1d = (qi * r).to(torch.int32)
+    full_kv_idx_2d = torch.where(
+        col_kv < full_kv_num_1d.unsqueeze(1),
+        col_kv.expand(n_q, n_kv),
+        torch.zeros_like(col_kv).expand(n_q, n_kv),
+    )
+
+    # PARTIAL: r diagonal + r * (n_rounds - 1) suffix per Q-block.
+    # Round s in [0, n_rounds): kj in [qi*r + s*n_kv_causal, (qi+1)*r + s*n_kv_causal).
+    partial_count = r * n_rounds
+    s_b = col_kv // r  # which round
+    sub_b = col_kv % r  # offset within round
+    kv_idx_2d = torch.where(
+        col_kv < partial_count,
+        qi_b * r + sub_b + s_b * n_kv_causal,
+        torch.zeros_like(col_kv).expand(n_q, n_kv),
+    )
+    kv_num_1d = torch.full((n_q,), partial_count, dtype=torch.int32, device=device)
+
+    # ---- Q direction (backward iteration: per KV-block, list Q-blocks) -----
+    # Same n_q-wide padding for FA4 backward block-sparsity.
+    kj = torch.arange(n_kv, device=device, dtype=torch.int32)
+    kj_b = kj.unsqueeze(1)
+    col_q = torch.arange(n_q, device=device, dtype=torch.int32).unsqueeze(0)
+    is_causal_kj = kj < n_kv_causal
+
+    # PARTIAL: every kj has exactly 1 partial Q-block (the diagonal Q-block).
+    # Causal kj: qi = kj // r.
+    # Suffix kj: qi = (kj % n_kv_causal) // r.
+    q_idx_diag = torch.where(is_causal_kj, kj // r, (kj % n_kv_causal) // r)
+    q_idx_2d = torch.where(
+        col_q == 0,
+        q_idx_diag.unsqueeze(1).expand(n_kv, n_q),
+        torch.zeros_like(col_q).expand(n_kv, n_q),
+    )
+    q_num_1d = torch.ones(n_kv, dtype=torch.int32, device=device)
+
+    # FULL: causal kj has (n_q - 1 - kj//r) full Q-blocks (qi > kj//r); suffix has 0.
+    diag_qi = kj_b // r
+    full_q_num_per_causal = (n_q - 1) - (kj // r)
+    full_q_num_1d = torch.where(is_causal_kj, full_q_num_per_causal, torch.zeros_like(kj)).to(
+        torch.int32
+    )
+    full_q_idx_2d = torch.where(
+        col_q < full_q_num_1d.unsqueeze(1),
+        diag_qi + 1 + col_q,
+        torch.zeros_like(col_q).expand(n_kv, n_q),
+    )
 
     # flex_attention iterates these directly; force contiguous storage.
-    kv_num = kv_num_1d.unsqueeze(0).unsqueeze(0).expand(B, H, n_q).contiguous()
-    kv_idx = kv_idx_2d.unsqueeze(0).unsqueeze(0).expand(B, H, n_q, max_kv_per_row).contiguous()
-    q_num = q_num_1d.unsqueeze(0).unsqueeze(0).expand(B, H, n_kv).contiguous()
-    q_idx = q_idx_2d.unsqueeze(0).unsqueeze(0).expand(B, H, n_kv, max_q_per_col).contiguous()
-    return kv_num, kv_idx, q_num, q_idx
+    def _expand_1d(t):
+        return t.unsqueeze(0).unsqueeze(0).expand(B, H, -1).contiguous()
+
+    def _expand_2d(t):
+        return t.unsqueeze(0).unsqueeze(0).expand(B, H, -1, -1).contiguous()
+
+    return (
+        _expand_1d(kv_num_1d),
+        _expand_2d(kv_idx_2d),
+        _expand_1d(full_kv_num_1d),
+        _expand_2d(full_kv_idx_2d),
+        _expand_1d(q_num_1d),
+        _expand_2d(q_idx_2d),
+        _expand_1d(full_q_num_1d),
+        _expand_2d(full_q_idx_2d),
+    )
 
 
 # dynamic=True so KV_LEN growing per TTT step doesn't recompile; inductor's
@@ -208,13 +271,21 @@ def _get_compiled_build_tensors():
     return _compiled_build_tensors
 
+
+def _normalize_block_size(BLOCK_SIZE) -> tuple[int, int]:
+    """Accept BLOCK_SIZE as int or (Q_BS, KV_BS) tuple; return (Q_BS, KV_BS)."""
+    if isinstance(BLOCK_SIZE, int):
+        return BLOCK_SIZE, BLOCK_SIZE
+    Q_BS, KV_BS = BLOCK_SIZE
+    return int(Q_BS), int(KV_BS)
+
+
 def build_eagle3_block_mask(
     Q_LEN: int,
     KV_LEN: int,
     B: int = 1,
     H: int = 1,
     device: torch.device = "cuda",
-    BLOCK_SIZE: int = 128,
+    BLOCK_SIZE: "int | tuple[int, int]" = 128,
 ) -> "BlockMask":
     """Build Eagle3 BlockMask analytically -- O(num_blocks) memory and time.
@@ -223,11 +294,19 @@
     directly from the known Eagle3 structure (causal first round + diagonal
     suffix rounds), so peak memory drops to a few MB.
 
-    Requires Q_LEN, KV_LEN multiples of BLOCK_SIZE and KV_LEN a multiple of
-    Q_LEN. Use ``eagle3_block_mask`` for the dispatching wrapper that falls
+    Requires Q_LEN multiple of Q_BS, KV_LEN multiple of KV_BS, KV_BS dividing
+    Q_BS, and KV_LEN a multiple of Q_LEN. ``BLOCK_SIZE`` accepts either an
+    int (square Q_BS=KV_BS) or a ``(Q_BS, KV_BS)`` tuple (FA4 Blackwell uses
+    256x128). Use ``eagle3_block_mask`` for the dispatching wrapper that falls
     back to create_block_mask otherwise.
+
+    Populates both PARTIAL (kv_*/q_*) and FULL (full_kv_*/full_q_*) block
+    tensors so flex_attention and FA4 BlockSparseTensorsTorch can both fast-path
+    fully-causal-below-diagonal blocks (skip mask_mod evaluation).
     """
-    assert Q_LEN % BLOCK_SIZE == 0 and KV_LEN % BLOCK_SIZE == 0
+    Q_BS, KV_BS = _normalize_block_size(BLOCK_SIZE)
+    assert Q_LEN % Q_BS == 0 and KV_LEN % KV_BS == 0
+    assert Q_BS % KV_BS == 0, f"Q_BS ({Q_BS}) must be a multiple of KV_BS ({KV_BS})"
     assert KV_LEN % Q_LEN == 0, (
         "build_eagle3_block_mask requires KV_LEN to be a multiple of Q_LEN; "
         f"got Q_LEN={Q_LEN}, KV_LEN={KV_LEN}"
@@ -239,7 +318,9 @@ def build_eagle3_block_mask(
         if is_torchdynamo_compiling()
         else _get_compiled_build_tensors()
     )
-    kv_num, kv_idx, q_num, q_idx = builder(Q_LEN, KV_LEN, B, H, BLOCK_SIZE, device)
+    kv_num, kv_idx, full_kv_num, full_kv_idx, q_num, q_idx, full_q_num, full_q_idx = builder(
+        Q_LEN, KV_LEN, B, H, Q_BS, KV_BS, device
+    )
 
     def mask_mod(b, h, q, kv):
         causal = (kv < Q_LEN) & (q >= kv)
@@ -250,13 +331,13 @@ def mask_mod(b, h, q, kv):
         seq_lengths=(Q_LEN, KV_LEN),
         kv_num_blocks=kv_num,
         kv_indices=kv_idx,
-        full_kv_num_blocks=None,
-        full_kv_indices=None,
+        full_kv_num_blocks=full_kv_num,
+        full_kv_indices=full_kv_idx,
         q_num_blocks=q_num,
         q_indices=q_idx,
-        full_q_num_blocks=None,
-        full_q_indices=None,
-        BLOCK_SIZE=(BLOCK_SIZE, BLOCK_SIZE),
+        full_q_num_blocks=full_q_num,
+        full_q_indices=full_q_idx,
+        BLOCK_SIZE=(Q_BS, KV_BS),
         mask_mod=mask_mod,
     )
@@ -268,17 +349,17 @@ def eagle3_block_mask(
     B: int = 1,
     H: int = 1,
     device: torch.device = "cuda",
-    BLOCK_SIZE: int = 128,
+    BLOCK_SIZE: "int | tuple[int, int]" = 128,
     lck: int = 0,
 ) -> "BlockMask":
     """Eagle3 block-mask dispatcher -- analytical when possible, fallback otherwise.
 
     Eagle3 training appends one full Q_LEN-sized round per step, so in
     normal training the analytical builder's preconditions
-    ``(Q_LEN % BLOCK_SIZE == 0 and KV_LEN % Q_LEN == 0)`` always hold. The
-    create_block_mask fallback only triggers for tests/edge cases (tiny
-    sequence lengths, non-aligned shapes), where its O(Q*KV) memory cost is
-    irrelevant.
+    ``(Q_LEN % Q_BS == 0 and KV_LEN % KV_BS == 0 and KV_LEN % Q_LEN == 0)``
+    always hold. The create_block_mask fallback only triggers for tests/edge
+    cases (tiny sequence lengths, non-aligned shapes), where its O(Q*KV)
+    memory cost is irrelevant.
 
     Args:
         Q_LEN: query length (current round).
@@ -286,7 +367,9 @@ def eagle3_block_mask(
         B: batch size for the BlockMask (broadcast-friendly when 1).
         H: head count for the BlockMask (broadcast-friendly when 1).
         device: target device.
-        BLOCK_SIZE: flex_attention block size; defaults to 128.
+        BLOCK_SIZE: flex_attention block size; ``int`` (square) or
+            ``(Q_BS, KV_BS)`` tuple. FA4 Blackwell SM100+ uses ``(256, 128)``.
+            Defaults to 128 (square).
         lck: number of completed rounds; only used to name the fallback
             mask_mod for debug clarity.
 
         A flex_attention BlockMask implementing the Eagle3 causal+suffix pattern.
""" - use_analytical = Q_LEN % BLOCK_SIZE == 0 and KV_LEN % BLOCK_SIZE == 0 and KV_LEN % Q_LEN == 0 + Q_BS, KV_BS = _normalize_block_size(BLOCK_SIZE) + use_analytical = ( + Q_LEN % Q_BS == 0 and KV_LEN % KV_BS == 0 and Q_BS % KV_BS == 0 and KV_LEN % Q_LEN == 0 + ) if use_analytical: return build_eagle3_block_mask( Q_LEN=Q_LEN, @@ -302,7 +388,7 @@ def eagle3_block_mask( B=B, H=H, device=device, - BLOCK_SIZE=BLOCK_SIZE, + BLOCK_SIZE=(Q_BS, KV_BS), ) # Fallback for non-aligned shapes (typically only seen in tests). @@ -316,4 +402,5 @@ def eagle3_block_mask( Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device, + BLOCK_SIZE=(Q_BS, KV_BS), )