Fix CP crash with GQA + asymmetric KV head dims (#2868)#2901
Conversation
Signed-off-by: Arkadii Be <beccohov@gmail.com>
Greptile SummaryThis PR fixes a shape mismatch crash in Confidence Score: 5/5Safe to merge — targeted one-variable fix with correct placement, backward-compatible under strict MLA (h_q == h_kv), and covered by new regression tests. All changed sites correctly distinguish the output shape (h_q heads) from the KV shape (h_kv heads). No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["AttnFuncWithCPAndKVP2P.forward()"] --> B["Causal reshape\n(bshd/sbhd only)"]
B --> C["v_shape = v.shape\nmla_out_shape = (*q.shape[:-1], v.shape[-1])"]
C --> D{enable_mla?}
D -- No --> E["out allocation\nuses q.shape (h_q, d_q)"]
D -- Yes --> F{qkv_format?}
F -- thd --> G["out = tensor.new_zeros(mla_out_shape)\n(h_q, d_v) ✅"]
F -- bshd/sbhd --> H["out = correction_init(...)\n.view(mla_out_shape)\n(h_q, d_v) ✅"]
G --> I["ctx.mla_out_shape saved"]
H --> I
I --> J["backward()"]
J --> K{enable_mla?}
K -- Yes --> L["out.view(*ctx.mla_out_shape)\ndout.view(*ctx.mla_out_shape)\n(h_q, d_v) ✅"]
K -- No --> M["out/dout.view(*q.shape)"]
style G fill:#d4edda
style H fill:#d4edda
style L fill:#d4edda
Reviews (1): Last reviewed commit: "Fix CP crash with GQA + asymmetric KV he..." | Re-trigger Greptile |
| k_shape = k.shape | ||
| k_numel = k.numel() | ||
| v_shape = v.shape | ||
| mla_out_shape = (*q.shape[:-1], v.shape[-1]) |
There was a problem hiding this comment.
mla_out_shape computed before the causal reshape for THD format
mla_out_shape = (*q.shape[:-1], v.shape[-1]) is placed at line 1564, which is after the causal reshape for bshd/sbhd (lines 1449-1455), so those formats get the correct 5-D layout. However, in the thd path there is no causal reshape — q retains its original [t, h_q, d_q] shape in all cases — so mla_out_shape computes correctly as (t, h_q, d_v) there too.
This is actually fine and the comment in the PR description ("Computed after the causal reshape so it parallels v_shape's 5D layout when causal") is accurate. Noting this for reviewers: the placement is intentional and correct, no action needed.
|
Thanks for raising this issue. I believe this has been incidentally fixed by #2719. |
Description
AttnFuncWithCPAndKVP2Ptreatsk.shape[-1] != v.shape[-1]as "MLA mode" andassumes the attention output shares V's shape. That holds for strict MLA
(h_q == h_kv) but not when MLA-style attention is combined with GQA
(h_q != h_kv), e.g. MiMo-V2-Flash with
num_attention_heads=64, num_kv_heads=4, head_dim_qk=192, head_dim_v=128.In that case
out.view(v_shape)andzeros_like(v).view(v_shape)fail with asize/shape mismatch because the attention output has
h_qheads while V hash_kvheads.Fixes #2868
Distinguish V's shape from the attention-output shape:
v_shape(unchanged): used for views of V and dV —[..., h_kv, d_v].mla_out_shape = (*q.shape[:-1], v.shape[-1])(new): used for views ofout / dout —
[..., h_q, d_v].Computed after the causal reshape so it parallels
v_shape's 5D layout whencausal. Under strict MLA (h_q == h_kv)
mla_out_shape == v_shape, so theexisting MLA path is unchanged.
Type of change
Changes
mla_out_shapeinAttnFuncWithCPAndKVP2P.forwardand save it onctx, separate from the existingv_shapewhich continues to describe V/dV.mla_out_shapefor out/dout views in forward (init zeros, post-correctionreshape) and backward (out/dout rehydration).
.new_zeros(mla_out_shape)instead oftorch.zeros_like(v).view(v_shape), so allocation size matchesh_q * d_vinstead of
h_kv * d_v.tests/pytorch/attention/test_attention_with_cp.pycombining GQA + asymmetric qk/v head dims:
model_configs_flash_attn:cp_3_4(causal),cp_3_5(non-causal).model_configs_fused_attn:cp_3_5(causal),cp_3_6(non-causal).test_essentialCI subset as regressionsentinels.
strict MLA and MLA-style attention combined with GQA.
Checklist: