Fix CP crash with GQA + asymmetric KV head dims (#2868) #2901

Closed
beccohov wants to merge 1 commit into NVIDIA:main from beccohov:bugfix/cp-gqa-asymmetric-kv-head-dims

Conversation

@beccohov

Description

AttnFuncWithCPAndKVP2P treats k.shape[-1] != v.shape[-1] as "MLA mode" and
assumes the attention output shares V's shape. That holds for strict MLA
(h_q == h_kv) but not when MLA-style attention is combined with GQA
(h_q != h_kv), e.g. MiMo-V2-Flash with num_attention_heads=64, num_kv_heads=4, head_dim_qk=192, head_dim_v=128.

In that case out.view(v_shape) and zeros_like(v).view(v_shape) fail with a
size/shape mismatch because the attention output has h_q heads while V has
h_kv heads.
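
As a rough illustration of the mismatch, the sketch below compares element counts using plain tuples rather than tensors, so it runs without PyTorch. The head counts and head dims are the MiMo-V2-Flash values quoted above; the token count `t` and the helper `can_view` are hypothetical and exist only for this example.

```python
from math import prod

# Hypothetical token count; head counts and dims are the values quoted above.
t, h_q, h_kv, d_qk, d_v = 8, 64, 4, 192, 128

q_shape = (t, h_q, d_qk)    # queries: h_q heads, qk head dim
v_shape = (t, h_kv, d_v)    # values: h_kv heads, v head dim
out_shape = (t, h_q, d_v)   # attention output: h_q heads, v head dim


def can_view(src_shape, dst_shape):
    """Equal element counts are a necessary condition for a view to succeed."""
    return prod(src_shape) == prod(dst_shape)


print(prod(out_shape), prod(v_shape))  # 65536 4096
print(can_view(out_shape, v_shape))    # False
```

The output holds h_q / h_kv = 16 times as many elements as V here, which is why a view of the output into V's shape cannot succeed.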

Fixes #2868

Distinguish V's shape from the attention-output shape:

  • v_shape (unchanged): used for views of V and dV — [..., h_kv, d_v].
  • mla_out_shape = (*q.shape[:-1], v.shape[-1]) (new): used for views of
    out / dout — [..., h_q, d_v].

mla_out_shape is computed after the causal reshape so that it parallels
v_shape's 5D layout in the causal case. Under strict MLA (h_q == h_kv),
mla_out_shape == v_shape, so the existing MLA path is unchanged.
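
A minimal sketch of the shape expression, again on plain tuples. The helper name `mla_out_shape_of` and all sizes are hypothetical, and the 5D causal layout [b, 2, s//2, h, d] for bshd is an assumption about how the causal reshape splits the sequence dimension; only the expression `(*q.shape[:-1], v.shape[-1])` comes from this PR.

```python
def mla_out_shape_of(q_shape, v_shape):
    """q's leading dims (including h_q) followed by V's head dim."""
    return (*q_shape[:-1], v_shape[-1])


# Hypothetical sizes for illustration.
b, s, d_qk, d_v = 2, 16, 192, 128

# Assumed causal bshd layout after the reshape: [b, 2, s//2, h, d].
q_5d = (b, 2, s // 2, 64, d_qk)
v_5d_gqa = (b, 2, s // 2, 4, d_v)    # GQA: h_kv = 4
v_5d_mla = (b, 2, s // 2, 64, d_v)   # strict MLA: h_kv == h_q

# With GQA the output shape keeps h_q heads, unlike v_shape.
assert mla_out_shape_of(q_5d, v_5d_gqa) == (2, 2, 8, 64, 128)
assert mla_out_shape_of(q_5d, v_5d_gqa) != v_5d_gqa

# Under strict MLA the two shapes coincide, so nothing changes.
assert mla_out_shape_of(q_5d, v_5d_mla) == v_5d_mla

# thd layout: [t, h, d], no causal reshape.
assert mla_out_shape_of((32, 64, d_qk), (32, 4, d_v)) == (32, 64, 128)
```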

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Introduce mla_out_shape in AttnFuncWithCPAndKVP2P.forward and save it on
    ctx, separate from the existing v_shape which continues to describe V/dV.
  • Use mla_out_shape for out/dout views in forward (init zeros, post-correction
    reshape) and backward (out/dout rehydration).
  • Allocate the first-step output with .new_zeros(mla_out_shape) instead of
    torch.zeros_like(v).view(v_shape), so allocation size matches h_q * d_v
    instead of h_kv * d_v.
  • Add tests in tests/pytorch/attention/test_attention_with_cp.py
    combining GQA + asymmetric qk/v head dims:
    • model_configs_flash_attn: cp_3_4 (causal), cp_3_5 (non-causal).
    • model_configs_fused_attn: cp_3_5 (causal), cp_3_6 (non-causal).
    • Causal cases added to the test_essential CI subset as regression
      sentinels.
  • Clarify in the code and test-config comments that this path covers both
    strict MLA and MLA-style attention combined with GQA.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Arkadii Be <beccohov@gmail.com>
@greptile-apps
Contributor

greptile-apps Bot commented Apr 19, 2026

Greptile Summary

This PR fixes a shape mismatch crash in AttnFuncWithCPAndKVP2P when combining GQA (h_q != h_kv) with MLA-style asymmetric key/value head dimensions. The fix introduces mla_out_shape = (*q.shape[:-1], v.shape[-1]) — capturing the attention output's h_q-head shape versus the existing v_shape's h_kv-head shape — and uses it consistently for output tensor allocation and reshaping in both forward and backward, while leaving all K/V buffer operations on v_shape unchanged.

Confidence Score: 5/5

Safe to merge — targeted one-variable fix with correct placement, backward-compatible under strict MLA (h_q == h_kv), and covered by new regression tests.

All changed sites correctly distinguish the output shape (h_q heads) from the KV shape (h_kv heads). v_shape is still used exclusively for K/V buffer views, and mla_out_shape is used for out/dout in both forward and backward. When h_q == h_kv the two shapes are identical, so existing MLA tests remain valid. New test configs exercise causal and non-causal GQA+MLA paths for both flash-attn and fused-attn backends. No P0/P1 findings.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | Introduces mla_out_shape = (*q.shape[:-1], v.shape[-1]) to decouple the attention-output shape (h_q heads) from the V/dV shape (h_kv heads) in the GQA+MLA path; correctly applied in forward (THD allocation, bshd/sbhd reshape) and backward (out/dout rehydration). |
| tests/pytorch/attention/test_attention_with_cp.py | Adds four new model configs (two for flash-attn, two for fused-attn) exercising GQA + asymmetric qk/v head dims; causal variants included in the test_essential CI subset. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["AttnFuncWithCPAndKVP2P.forward()"] --> B["Causal reshape\n(bshd/sbhd only)"]
    B --> C["v_shape = v.shape\nmla_out_shape = (*q.shape[:-1], v.shape[-1])"]
    C --> D{enable_mla?}
    D -- No --> E["out allocation\nuses q.shape (h_q, d_q)"]
    D -- Yes --> F{qkv_format?}
    F -- thd --> G["out = tensor.new_zeros(mla_out_shape)\n(h_q, d_v) ✅"]
    F -- bshd/sbhd --> H["out = correction_init(...)\n.view(mla_out_shape)\n(h_q, d_v) ✅"]
    G --> I["ctx.mla_out_shape saved"]
    H --> I
    I --> J["backward()"]
    J --> K{enable_mla?}
    K -- Yes --> L["out.view(*ctx.mla_out_shape)\ndout.view(*ctx.mla_out_shape)\n(h_q, d_v) ✅"]
    K -- No --> M["out/dout.view(*q.shape)"]
    style G fill:#d4edda
    style H fill:#d4edda
    style L fill:#d4edda

Reviews (1): Last reviewed commit: "Fix CP crash with GQA + asymmetric KV he..."

k_shape = k.shape
k_numel = k.numel()
v_shape = v.shape
mla_out_shape = (*q.shape[:-1], v.shape[-1])
Contributor

P2 mla_out_shape computed before the causal reshape for THD format

mla_out_shape = (*q.shape[:-1], v.shape[-1]) is placed at line 1564, which is after the causal reshape for bshd/sbhd (lines 1449-1455), so those formats get the correct 5-D layout. However, in the thd path there is no causal reshape — q retains its original [t, h_q, d_q] shape in all cases — so mla_out_shape computes correctly as (t, h_q, d_v) there too.

This is actually fine and the comment in the PR description ("Computed after the causal reshape so it parallels v_shape's 5D layout when causal") is accurate. Noting this for reviewers: the placement is intentional and correct, no action needed.

@ptrendx added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Apr 20, 2026
@cyanguwa
Collaborator

Thanks for raising this issue. I believe this has been incidentally fixed by #2719.

out = torch.zeros_like(out_per_step[0]).view(o_shape)

@cyanguwa closed this on Apr 22, 2026

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Context parallel crashes with asymmetric K/V head dims (GQA + enable_mla path)

3 participants