-
Notifications
You must be signed in to change notification settings - Fork 738
[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism #2829
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 16 commits
db44fc2
1a5ca4c
b4db9eb
7491ab6
b957725
c89173c
18e41bd
0b48746
608106d
4b95130
15af3af
5bec5b3
89b1066
55fc2cd
f499f59
2569a65
4e4212f
10e4cfc
2a49dee
34e3d62
4745f98
4be004f
c476f15
a2b0f1b
0ee22c7
dfc1472
ac38d4f
b94e175
7ebe3d9
ddaa196
fc9182f
636666f
7928bc9
1585ebb
7ecad01
d8bf5c5
cc104d3
2464f43
26e9f6f
611d876
0aae820
789ccf0
0a32185
c33cf2d
353361a
a1062d9
1d4e170
29785a0
09b01c9
a329afb
2dc5c15
5f606ae
24a95ab
b1faebb
8c44fcb
628f73c
b897900
669342a
0e926c4
ed28a8b
696ea9b
90ab1c7
a72e70b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -100,7 +100,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): | |
| pytest.skip("CP implementation with KV all-gather does not support bias yet!") | ||
| if qkv_format == "thd": | ||
| if cp_comm_type == "all_gather": | ||
| pytest.skip("CP implementation with KV all-gather does not support THD format yet!") | ||
| pytest.skip( | ||
| "FlashAttention does not support THD padding; use FusedAttention for" | ||
| " THD+all_gather CP." | ||
| ) | ||
| if cp_comm_type == "a2a+p2p": | ||
| pytest.skip( | ||
| "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" | ||
|
|
@@ -267,8 +270,6 @@ def test_cp_with_fused_attention( | |
| if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": | ||
| pytest.skip("THD format does not support post_scale_bias yet!") | ||
| if qkv_format == "thd": | ||
| if cp_comm_type == "all_gather": | ||
| pytest.skip("CP implementation with KV all-gather does not support THD format yet!") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A general comment - please run the CP file with "test_essential=False" offline because the essential tests may not cover everything.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Status: not claiming this broad offline request as fully closed yet. Current validation covers the focused FusedAttention THD all_gather pytest, FA3 THD all_gather bucket32k correctness, and prior cp2/cp4/cp8 AG-vs-a2a sweeps. Plan is to run the full CP file with |
||
| if cp_comm_type == "a2a+p2p": | ||
| pytest.skip( | ||
| "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe swap the words a little bit so it doesn't sounds like FlashAttention doesn't support THD, but just our CP implementation with it doesn't? (Also, THD implies padding in our terminology?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Resolved in the current test wording. The skip message now says the CP implementation with QKVO A2A+P2P / hierarchical A2A does not support THD format, rather than implying FlashAttention itself does not support THD.