
Commit 3ecfe6b

drisspg authored and pytorchmergebot committed
[Submodule] Turning flash-attention integration into 3rd party submod (pytorch#144120) (pytorch#146372)
Summary:

### Sticky points
Cuda-graph rng handling has changed / deviated from the original implementation. We will be left with a dangling 'offset' val and confusing naming due to BC.

## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions, which is not a real BC surface.

cc albanD

imported-using-ghimport

Test Plan: Imported from OSS

Building in dev:

`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a //caffe2:ATen-cu --show-full-output`

and running nm on the resulting .so, I see that the flash symbols are correctly named:

```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg

Pull Request resolved: pytorch#146372
Approved by: https://github.com/jbschlosser
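The pytorch_flash:: prefix in the nm output above comes from building the submodule's kernels inside a macro-selected namespace; the attention.cu and attention_backward.cu changes below switch the call sites to that FLASH_NAMESPACE macro. The following is only a minimal sketch of the renaming pattern, assuming the submodule's namespace_config.h works roughly this way; the struct and function bodies are placeholders, not the real declarations.

```cpp
// Sketch (not the submodule's actual namespace_config.h): kernels are declared in a
// macro-selected namespace, so PyTorch can build them as pytorch_flash:: without
// colliding with a separately installed flash-attention package.
#ifndef FLASH_NAMESPACE
#define FLASH_NAMESPACE pytorch_flash  // assumed default for the PyTorch build
#endif

namespace FLASH_NAMESPACE {
// Placeholder stand-ins for the real Flash_fwd_params / run_mha_fwd declarations.
struct Flash_fwd_params {};
inline void run_mha_fwd(Flash_fwd_params& /*params*/) {}
} // namespace FLASH_NAMESPACE

int main() {
  FLASH_NAMESPACE::Flash_fwd_params params;
  // Both spellings resolve to the same symbol, matching the nm output above.
  pytorch_flash::run_mha_fwd(params);
  return 0;
}
```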
1 parent 276dfe8 commit 3ecfe6b

File tree

75 files changed: +186 -5749 lines changed


CMakeLists.txt

-5 lines changed

@@ -872,11 +872,6 @@ cmake_dependent_option(
   "USE_CUDA OR USE_ROCM;NOT MSVC"
   OFF)
 
-# We are currenlty not using alibi attention for Flash So we disable this
-# feature by default We dont currently document this feature because we don't
-# Suspect users building from source will need this
-add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
-
 # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
 # Eff Attention won't
 cmake_dependent_option(
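For context on the deleted definition: FLASHATTENTION_DISABLE_ALIBI is a compile-time switch that strips the alibi-bias code paths out of the kernels, and it is no longer forced on now that the kernels come from the submodule. The sketch below only illustrates the usual #ifdef gating pattern for such a flag; it is not the flash-attention source, and the Params struct is a made-up stand-in.

```cpp
// Illustrative sketch of how a compile-out flag like FLASHATTENTION_DISABLE_ALIBI
// is typically consumed; this is not the actual flash-attention kernel code.
#include <cstdio>

struct Params {
  const float* alibi_slopes_ptr = nullptr;  // non-null when an alibi bias is requested
};

bool has_alibi(const Params& p) {
#ifdef FLASHATTENTION_DISABLE_ALIBI
  (void)p;
  return false;  // alibi support compiled out entirely
#else
  return p.alibi_slopes_ptr != nullptr;  // decided at run time
#endif
}

int main() {
  Params p;
  std::printf("alibi enabled: %d\n", has_alibi(p) ? 1 : 0);
  return 0;
}
```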

aten/src/ATen/CMakeLists.txt

+6 -3 lines changed

@@ -164,9 +164,12 @@ file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
 file(GLOB native_utils_cpp "native/utils/*.cpp")
 
 # flash_attention sources
-file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
-file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
-file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
+file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
+# Flash attention C++ sources
+file(GLOB flash_attention_cuda_cpp
+  "${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp"
+  "native/transformers/cuda/flash_attn/flash_api.cpp"
+)
 
 # flash_attention hip sources
 file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")

aten/src/ATen/native/native_functions.yaml

+3 -3 lines changed

@@ -14852,7 +14852,7 @@
     MPS: _scaled_dot_product_attention_math_mps
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda

@@ -14909,13 +14909,13 @@
     CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
   tags: nondeterministic_seeded
 
-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
   variants: function
   dispatch:
     CUDA: _flash_attention_forward
   tags: nondeterministic_seeded
 
-- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
+- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:
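The schema changes above only rename the returned RNG tensors: philox_seed/philox_offset become rng_state plus a backward-compatibility placeholder called unused, while the tuple arity stays the same. Below is a hedged sketch of what that looks like at a C++ call site, assuming a libtorch build with CUDA and flash attention available; the shapes and dtypes are illustrative only.

```cpp
#include <ATen/ATen.h>

int main() {
  // Illustrative shapes: [batch, num_heads, seq_len, head_dim], half precision on CUDA.
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto q = at::randn({2, 8, 128, 64}, opts);
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);

  // Positions and count of the outputs are unchanged; only the 7th and 8th
  // outputs are renamed to rng_state / unused in the schema above.
  auto [out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
        rng_state, unused, debug_attn_mask] =
      at::_scaled_dot_product_flash_attention(
          q, k, v, /*dropout_p=*/0.0, /*is_causal=*/false, /*return_debug_mask=*/false);

  // rng_state carries the RNG state (philox seed/offset) used for dropout; `unused`
  // is the dangling backward-compatibility slot mentioned in the commit summary.
  return 0;
}
```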

aten/src/ATen/native/transformers/cuda/attention.cu

+8 -2 lines changed

@@ -70,6 +70,9 @@
 #ifdef USE_FLASH_ATTENTION
 // FlashAttention Specific Imports
 #include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
+#if !defined(__HIP_PLATFORM_AMD__)
+#include <namespace_config.h>
+#endif
 #endif
 #ifdef USE_MEM_EFF_ATTENTION
 #ifndef USE_ROCM

@@ -916,6 +919,7 @@ _flash_attention_forward(
   std::optional<Tensor> seqused_k = _seqused_k;
   std::optional<at::Tensor> block_table = std::nullopt; // we are not using the block table yet
   std::optional<Tensor> alibi_slopes = _alibi_slopes;
+  const float softcap = 0.0;
 
   const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
   const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;

@@ -939,7 +943,7 @@ _flash_attention_forward(
       philox_seed,
       philox_offset,
       debug_attn_mask) =
-      pytorch_flash::mha_varlen_fwd(
+      FLASH_NAMESPACE::mha_varlen_fwd(
          query,
          key,
          value,

@@ -957,6 +961,7 @@ _flash_attention_forward(
          is_causal,
          non_null_window_left,
          non_null_window_right,
+         softcap,
          return_debug_mask,
          std::nullopt /*gen_*/);
   } else {

@@ -969,7 +974,7 @@ _flash_attention_forward(
       philox_seed,
       philox_offset,
       debug_attn_mask) =
-      pytorch_flash::mha_fwd(
+      FLASH_NAMESPACE::mha_fwd(
          query,
          key,
          value,

@@ -980,6 +985,7 @@ _flash_attention_forward(
          is_causal,
          non_null_window_left,
          non_null_window_right,
+         softcap,
          return_debug_mask, /*return_softmax (this is used for testing)*/
          std::nullopt);
   }
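The functional additions in this file are the namespace switch and the new softcap argument, which PyTorch passes as 0.0 (disabled). For readers unfamiliar with the parameter, a non-zero softcap squashes attention scores before the softmax, commonly as softcap * tanh(score / softcap); the sketch below illustrates that behavior under this assumption and is not the kernel's actual code path.

```cpp
// Standalone illustration of logit soft-capping (assumed form: softcap * tanh(score / softcap)).
// PyTorch passes softcap = 0.0 above, which leaves attention scores untouched.
#include <cmath>
#include <cstdio>

float apply_softcap(float score, float softcap) {
  if (softcap == 0.0f) {
    return score;  // disabled: the value hard-coded in this commit
  }
  return softcap * std::tanh(score / softcap);
}

int main() {
  std::printf("softcap=0:  %.2f\n", apply_softcap(42.0f, 0.0f));   // 42.00 (unchanged)
  std::printf("softcap=30: %.2f\n", apply_softcap(42.0f, 30.0f));  // squashed into (-30, 30)
  return 0;
}
```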

aten/src/ATen/native/transformers/cuda/attention_backward.cu

+5 -2 lines changed

@@ -94,6 +94,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
 
   // Currently unused args:
   std::optional<at::Tensor> alibi_slopes{std::nullopt};
+  const float softcap = 0.0;
 
   bool determinisitic{false};
   auto& ctx = at::globalContext();

@@ -111,7 +112,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
   // in order to determine whether we are using varlen or dense forward
   if (cumulative_sequence_length_q.defined()) {
     // Varlen forward
-    auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_varlen_bwd(
+    auto [dQuery, dKey, dValue, dSoftmax] = FLASH_NAMESPACE::mha_varlen_bwd(
        contiguous_grad_out,
        query,
        key,

@@ -132,13 +133,14 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
        is_causal,
        non_null_window_left,
        non_null_window_right,
+       softcap,
        determinisitic,
        philox_seed,
        philox_offset);
     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
   } else {
     // Dense forward
-    auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd(
+    auto [dQuery, dKey, dValue, dSoftmax] = FLASH_NAMESPACE::mha_bwd(
        contiguous_grad_out,
        query,
        key,

@@ -154,6 +156,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
        is_causal,
        non_null_window_left,
        non_null_window_right,
+       softcap,
        determinisitic,
        philox_seed,
        philox_offset);

aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h

-74
This file was deleted.

aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h

-46
This file was deleted.

aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h

-96
This file was deleted.
