-
Notifications
You must be signed in to change notification settings - Fork 757
[PyT] Plumbing correct bias dims from TE to cudnn, while adding support for additional bias shapes #2537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PyT] Plumbing correct bias dims from TE to cudnn, while adding support for additional bias shapes #2537
Changes from 22 commits
ab542fc
c86328e
fddf0ac
d3aa7ec
4d295c4
f4f9cc6
f9f4fb8
d9547fa
7ede1fe
c20d67a
303aee7
e9f88f0
6bf73e1
ebee29b
126be03
7b0f942
f795056
0e74dcf
0acf8f8
2133bd8
89e90a5
5a25d9c
0e2a72f
f066c88
ff174a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,13 +52,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( | |
| int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, | ||
| int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, | ||
| int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, | ||
| int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, | ||
| bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, | ||
| NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, | ||
| int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, | ||
| void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, | ||
| void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, | ||
| void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, | ||
| int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, | ||
| bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, | ||
| NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, | ||
| NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, | ||
| bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, | ||
| void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, | ||
| void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, | ||
| void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, | ||
| void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, | ||
| void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { | ||
| using namespace transformer_engine; | ||
|
|
@@ -121,6 +122,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( | |
| max_pages_per_seq_v, | ||
| bias_b, | ||
| bias_h, | ||
| bias_sq, | ||
| bias_skv, | ||
| scaling_factor, | ||
| is_training, | ||
| dropout_probability, | ||
|
|
@@ -270,10 +273,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( | |
| sdpa_options.set_alibi_mask(is_alibi); | ||
|
|
||
| if (is_bias) { | ||
| bias = mha_graph->tensor(fe::graph::Tensor_attributes() | ||
| .set_name("bias") | ||
| .set_dim({bias_b, bias_h, s_q, s_kv}) | ||
| .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); | ||
| bias = mha_graph->tensor( | ||
| fe::graph::Tensor_attributes() | ||
| .set_name("bias") | ||
| .set_dim({bias_b, bias_h, bias_sq, bias_skv}) | ||
| .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); | ||
| sdpa_options.set_bias(bias); | ||
| } | ||
|
|
||
|
|
@@ -549,16 +553,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( | |
| void fused_attn_arbitrary_seqlen_bwd_impl( | ||
| int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, | ||
| int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, | ||
| float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, | ||
| NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, | ||
| int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, | ||
| bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, | ||
| void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset, | ||
| void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, | ||
| void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, | ||
| void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, | ||
| void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, | ||
| size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { | ||
| int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, | ||
| NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, | ||
| NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, | ||
| bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose, | ||
| void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, | ||
| void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, | ||
| void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, | ||
| void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, | ||
| void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, | ||
| void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { | ||
| using namespace transformer_engine; | ||
|
|
||
| bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); | ||
|
|
@@ -623,6 +627,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( | |
| 0, | ||
| bias_b, | ||
| bias_h, | ||
| bias_sq, | ||
| bias_skv, | ||
| scaling_factor, | ||
| true, | ||
| dropout_probability, | ||
|
|
@@ -812,19 +818,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl( | |
| sdpa_backward_options.set_alibi_mask(is_alibi); | ||
|
|
||
| if (is_bias) { | ||
| bias = mha_graph->tensor(fe::graph::Tensor_attributes() | ||
| .set_name("bias") | ||
| .set_dim({bias_b, bias_h, s_q, s_kv}) | ||
| .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); | ||
| dBias = mha_graph->tensor(fe::graph::Tensor_attributes() | ||
| .set_name("dBias") | ||
| .set_dim({bias_b, bias_h, s_q, s_kv}) | ||
| .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); | ||
| bias = mha_graph->tensor( | ||
| fe::graph::Tensor_attributes() | ||
| .set_name("bias") | ||
| .set_dim({bias_b, bias_h, bias_sq, bias_skv}) | ||
| .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); | ||
| dBias = mha_graph->tensor( | ||
| fe::graph::Tensor_attributes() | ||
| .set_name("dBias") | ||
| .set_dim({bias_b, bias_h, bias_sq, bias_skv}) | ||
| .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); | ||
| sdpa_backward_options.set_bias(bias); | ||
| // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] | ||
| // are not supported for dbias calculation but they are | ||
| // supported for forward bias calculation | ||
| if ((bias_b == 1) && (bias_h == h)) { | ||
| // bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation | ||
| // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 | ||
| if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { | ||
| sdpa_backward_options.set_dbias(dBias); | ||
| } | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unused When In practice the code path where if (is_bias) {
bias = mha_graph->tensor(/* ... */);
sdpa_backward_options.set_bias(bias);
// [1,1,1,s] is not supported for dbias as of cuDNN 9.18
if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) {
dBias = mha_graph->tensor(/* same dims */);
sdpa_backward_options.set_dbias(dBias);
}
} This would also simplify the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unrelated to this PR, however, I think it is a good change to create the tensor as needed. |
||
|
|
@@ -975,7 +982,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( | |
|
|
||
| if (is_bias) { | ||
| variant_pack[bias] = devPtrBias; | ||
| if ((bias_b == 1) && (bias_h == h)) { | ||
| // bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation | ||
| // bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18 | ||
| if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) { | ||
| variant_pack[dBias] = devPtrdBias; | ||
| } else { | ||
| variant_pack[dBias] = nullptr; | ||
|
|
@@ -1084,10 +1093,14 @@ void fused_attn_arbitrary_seqlen_fwd( | |
| void *devPtrBias = nullptr; | ||
| size_t bias_b = 0; | ||
| size_t bias_h = 0; | ||
| size_t bias_sq = 0; | ||
| size_t bias_skv = 0; | ||
| if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { | ||
| devPtrBias = input_Bias->data.dptr; | ||
| bias_b = input_Bias->data.shape[0]; | ||
| bias_h = input_Bias->data.shape[1]; | ||
| bias_sq = input_Bias->data.shape[2]; | ||
| bias_skv = input_Bias->data.shape[3]; | ||
| } | ||
| void *devPtrSoftmaxOffset = nullptr; | ||
| if (softmax_type != NVTE_VANILLA_SOFTMAX) { | ||
|
|
@@ -1153,7 +1166,7 @@ void fused_attn_arbitrary_seqlen_fwd( | |
| if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { | ||
| Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); | ||
| output_bias->data.dptr = nullptr; | ||
| output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; | ||
| output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv}; | ||
| output_bias->data.dtype = QKV_type; | ||
| } | ||
|
|
||
|
|
@@ -1198,10 +1211,10 @@ void fused_attn_arbitrary_seqlen_fwd( | |
| fused_attn_arbitrary_seqlen_fwd_impl( | ||
| batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, | ||
| max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, | ||
| page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, | ||
| return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, | ||
| window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, | ||
| devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, | ||
| page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, | ||
| is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, | ||
| softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, | ||
| devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, | ||
| devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, | ||
| devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, | ||
| &workspace_size, stream, handle); | ||
|
|
@@ -1245,11 +1258,15 @@ void fused_attn_arbitrary_seqlen_bwd( | |
| void *devPtrdBias = nullptr; | ||
| size_t bias_b = 0; | ||
| size_t bias_h = 0; | ||
| size_t bias_sq = 0; | ||
| size_t bias_skv = 0; | ||
| if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { | ||
| devPtrBias = input_Bias->data.dptr; | ||
| devPtrdBias = output_dBias->data.dptr; | ||
| bias_b = output_dBias->data.shape[0]; | ||
| bias_h = output_dBias->data.shape[1]; | ||
| bias_sq = output_dBias->data.shape[2]; | ||
| bias_skv = output_dBias->data.shape[3]; | ||
| } | ||
|
|
||
| size_t max_batch_size = 0; | ||
|
|
@@ -1292,11 +1309,11 @@ void fused_attn_arbitrary_seqlen_bwd( | |
|
|
||
| fused_attn_arbitrary_seqlen_bwd_impl( | ||
| batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, | ||
| max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, | ||
| qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, | ||
| bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, | ||
| devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, | ||
| devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, | ||
| max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, | ||
| p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, | ||
| window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, | ||
| devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, | ||
| devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, | ||
| devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), | ||
| workspace->data.dptr, &workspace_size, stream, handle); | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SWA test coverage reduced
The old
`cp_1_4` (now `cp_1_5`) had `window_size=(512, 512)`, which tested bidirectional sliding window attention. This was changed to `window_size=(512, 0)` (left-only SWA). The same applies to the old `cp_2_4` (now `cp_2_6`) at line 190. Was this intentional? If so, there's no longer any fused attention CP test covering `window_size_right > 0`. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good catch by greptile as well
Seems like it was altered during testing and never reinstated :(
Done in ff174a8