optimize compilation
minhthuc2502 committed Apr 8, 2024
1 parent c726755 commit 845a327
Showing 4 changed files with 113 additions and 332 deletions.
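Judging from the visible hunks, the optimization is to drop the Is_dropout compile-time flag from the forward flash-attention kernels (CTranslate2 only runs inference, so dropout is always off). Every boolean removed from a kernel's template list halves the number of specializations nvcc has to build. A minimal, hypothetical sketch of that effect — the dispatcher and kernel below are illustrative, not code from this repository:

#include <cstdio>

// Hypothetical kernel template: each bool parameter doubles the number of
// distinct instantiations the compiler must emit.
template <bool Is_dropout, bool Is_causal, bool Is_even_K>
void run_fwd_kernel() {
    std::printf("dropout=%d causal=%d even_k=%d\n", Is_dropout, Is_causal, Is_even_K);
}

// Runtime-to-compile-time dispatch (the usual BOOL_SWITCH pattern): while
// Is_dropout is still a template parameter, both branches get compiled for
// every combination of the remaining flags.
template <bool Is_causal, bool Is_even_K>
void dispatch_with_dropout(bool is_dropout) {
    if (is_dropout) run_fwd_kernel<true,  Is_causal, Is_even_K>();
    else            run_fwd_kernel<false, Is_causal, Is_even_K>();
}

// After the change: dropout is hard-coded off, so only half of those
// specializations are ever instantiated.
template <bool Is_causal, bool Is_even_K>
void dispatch_without_dropout() {
    run_fwd_kernel</*Is_dropout=*/false, Is_causal, Is_even_K>();
}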
5 changes: 0 additions & 5 deletions CMakeLists.txt
@@ -490,11 +490,6 @@ if (WITH_CUDA)
cuda_include_directories(${THRUST_INCLUDE_DIRS})
list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${THRUST_INCLUDE_DIRS})

- #if(NOT DEFINED ${CMAKE_CUDA_ARCHITECTURES})
- # set(CMAKE_CUDA_ARCHITECTURES 70)
- #endif()
- #set(CMAKE_CUDA_COMPILER ${CUDA_HOST_COMPILER})
- #add_subdirectory(third_party/cutlass)
set(CUTLASS_INCLUDE_DIRS
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include
)
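The removed CMake lines were already commented out (a default CMAKE_CUDA_ARCHITECTURES, a CUDA compiler override, and an add_subdirectory() for CUTLASS); CUTLASS remains a header-only dependency, so publishing its include directory is all the build needs. As a rough, assumed illustration of what the kernel sources pull in once that include path is set:

// Illustrative translation unit (not from the repository): with CUTLASS used
// header-only, the compiler only needs the include path
// third_party/cutlass/include -- no CUTLASS library target is built.
#include <cute/tensor.hpp>          // CuTe layout/tensor primitives
#include <cutlass/numeric_types.h>  // cutlass::half_t, cutlass::bfloat16_t

using Element = cutlass::half_t;    // fp16 element type used by the kernels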
10 changes: 5 additions & 5 deletions include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h
@@ -23,7 +23,7 @@ namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+ template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

using Element = typename Kernel_traits::Element;
@@ -375,7 +375,7 @@ namespace flash {

// Epilogue

- Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
+ Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.rp_dropout);

// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
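The empty angle brackets work because the dropout flag no longer needs to be spelled out at the call site; presumably the softmax helper now takes only a defaulted Split parameter. A signature sketch under that assumption (not copied from the repository's softmax header):

// Assumed shape of the helper after this commit: Is_dropout is gone from the
// template list and Split defaults to false, so callers write <> here and
// <Split> in the split-KV path further down.
template <bool Split = false, typename Tensor0>
__forceinline__ __device__ auto normalize_softmax_lse(Tensor0 &acc_o,
                                                      float softmax_scale,
                                                      float rp_dropout = 1.0f);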
@@ -931,7 +931,7 @@ namespace flash {

// Epilogue

- Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
+ Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax);
// if (cute::thread0()) { print(lse); }

Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
@@ -1009,7 +1009,7 @@ namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+ template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -1025,7 +1025,7 @@
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

- flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+ flash::compute_attn_1rowblock<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
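compute_attn itself is just a forwarding shim; the real compile-time win comes from the __global__ launch wrappers and switch macros that enumerate every template combination, each of which now has one fewer boolean to expand. A hedged sketch of such a wrapper (illustrative, not the file's actual launcher):

// Illustrative __global__ wrapper: every bool that reaches compute_attn via
// its template list must be enumerated by the host-side dispatch code, so
// each flag removed here halves the number of kernels nvcc compiles.
template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
          bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
__global__ void flash_fwd_kernel(const Params params) {
    flash::compute_attn<Kernel_traits, Is_causal, Is_local, Has_alibi,
                        Is_even_MN, Is_even_K, Return_softmax>(params);
}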
