From e6e8f95a7c6bd54942385538e9b756c09e27c444 Mon Sep 17 00:00:00 2001
From: minhthuc
Date: Mon, 8 Apr 2024 17:56:44 +0200
Subject: [PATCH 1/3] flash attention support

---
 .gitmodules | 3 +
 CMakeLists.txt | 51 +-
 include/ctranslate2/layers/attention.h | 98 +-
 include/ctranslate2/layers/attention_layer.h | 136 ++
 include/ctranslate2/layers/flash_attention.h | 53 +
 include/ctranslate2/layers/layers.h | 1 +
 include/ctranslate2/layers/transformer.h | 21 +-
 include/ctranslate2/models/model.h | 8 +
 .../ctranslate2/ops/flash-attention/alibi.h | 74 +
 .../ops/flash-attention/block_info.h | 46 +
 .../ctranslate2/ops/flash-attention/flash.h | 178 +++
 .../ops/flash-attention/flash_fwd_kernel.h | 1207 +++++++++++++++++
 .../flash_fwd_launch_template.h | 334 +++++
 .../ops/flash-attention/kernel_traits.h | 160 +++
 .../ctranslate2/ops/flash-attention/mask.h | 213 +++
 .../ops/flash-attention/philox.cuh | 51 +
 .../ctranslate2/ops/flash-attention/rotary.h | 152 +++
 .../ctranslate2/ops/flash-attention/softmax.h | 185 +++
 .../ops/flash-attention/static_switch.h | 108 ++
 .../ctranslate2/ops/flash-attention/utils.h | 395 ++++++
 include/ctranslate2/ops/flash_attention.h | 44 +
 include/ctranslate2/ops/ops.h | 1 +
 include/ctranslate2/ops/rotary.h | 6 +-
 include/ctranslate2/utils.h | 6 +-
 python/cpp/encoder.cc | 4 +-
 python/cpp/generator.cc | 4 +-
 python/cpp/replica_pool.h | 2 +
 python/cpp/translator.cc | 6 +-
 python/cpp/wav2vec2.cc | 4 +-
 python/cpp/whisper.cc | 4 +-
 .../prepare_build_environment_windows.sh | 4 +-
 src/cuda/utils.h | 6 +
 src/layers/attention.cc | 306 +----
 src/layers/attention_layer.cc | 271 ++++
 src/layers/flash_attention.cc | 177 +++
 src/layers/transformer.cc | 60 +-
 src/models/model.cc | 30 +-
 .../flash_fwd_hdim128_bf16_sm80.cu | 10 +
 .../flash_fwd_hdim128_fp16_sm80.cu | 10 +
 .../flash_fwd_hdim160_bf16_sm80.cu | 10 +
 .../flash_fwd_hdim160_fp16_sm80.cu | 10 +
 .../flash_fwd_hdim192_bf16_sm80.cu | 10 +
 .../flash_fwd_hdim192_fp16_sm80.cu | 10 +
 .../flash_fwd_hdim224_bf16_sm80.cu | 10 +
 .../flash_fwd_hdim224_fp16_sm80.cu | 10 +
 .../flash_fwd_hdim256_bf16_sm80.cu | 10 +
 .../flash_fwd_hdim256_fp16_sm80.cu | 10 +
 .../flash_fwd_hdim32_bf16_sm80.cu | 10 +
 .../flash_fwd_hdim32_fp16_sm80.cu | 10 +
 .../flash_fwd_hdim64_bf16_sm80.cu | 10 +
 .../flash_fwd_hdim64_fp16_sm80.cu | 10 +
 .../flash_fwd_hdim96_bf16_sm80.cu | 10 +
 .../flash_fwd_hdim96_fp16_sm80.cu | 10 +
 .../flash_fwd_split_hdim128_bf16_sm80.cu | 7 +
 .../flash_fwd_split_hdim128_fp16_sm80.cu | 7 +
 .../flash_fwd_split_hdim160_bf16_sm80.cu | 7 +
 .../flash_fwd_split_hdim160_fp16_sm80.cu | 7 +
 .../flash_fwd_split_hdim192_bf16_sm80.cu | 7 +
 .../flash_fwd_split_hdim192_fp16_sm80.cu | 7 +
 .../flash_fwd_split_hdim224_bf16_sm80.cu | 7 +
 .../flash_fwd_split_hdim224_fp16_sm80.cu | 7 +
 .../flash_fwd_split_hdim256_bf16_sm80.cu | 7 +
 .../flash_fwd_split_hdim256_fp16_sm80.cu | 7 +
 .../flash_fwd_split_hdim32_bf16_sm80.cu | 7 +
 .../flash_fwd_split_hdim32_fp16_sm80.cu | 7 +
 .../flash_fwd_split_hdim64_bf16_sm80.cu | 7 +
 .../flash_fwd_split_hdim64_fp16_sm80.cu | 7 +
 .../flash_fwd_split_hdim96_bf16_sm80.cu | 7 +
 .../flash_fwd_split_hdim96_fp16_sm80.cu | 7 +
 src/ops/flash_attention.cc | 31 +
 src/ops/flash_attention_cpu.cc | 24 +
 src/ops/flash_attention_gpu.cu | 367 +++++
 src/ops/rotary.cc | 5 +-
 src/ops/rotary_cpu.cc | 8 +-
 src/ops/rotary_gpu.cu | 19 +-
 75 files changed, 4726 insertions(+), 409 deletions(-)
 create mode 100644 include/ctranslate2/layers/attention_layer.h
 create mode 100644 include/ctranslate2/layers/flash_attention.h
 create mode 100644 include/ctranslate2/ops/flash-attention/alibi.h
 create mode 100644 include/ctranslate2/ops/flash-attention/block_info.h
 create mode 100644 include/ctranslate2/ops/flash-attention/flash.h
 create mode 100644 include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h
 create mode 100644 include/ctranslate2/ops/flash-attention/flash_fwd_launch_template.h
 create mode 100644 include/ctranslate2/ops/flash-attention/kernel_traits.h
 create mode 100644 include/ctranslate2/ops/flash-attention/mask.h
 create mode 100644 include/ctranslate2/ops/flash-attention/philox.cuh
 create mode 100644 include/ctranslate2/ops/flash-attention/rotary.h
 create mode 100644 include/ctranslate2/ops/flash-attention/softmax.h
 create mode 100644 include/ctranslate2/ops/flash-attention/static_switch.h
 create mode 100644 include/ctranslate2/ops/flash-attention/utils.h
 create mode 100644 include/ctranslate2/ops/flash_attention.h
 create mode 100644 src/layers/attention_layer.cc
 create mode 100644 src/layers/flash_attention.cc
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
 create mode 100644 src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
 create mode 100644 src/ops/flash_attention.cc
 create mode 100644 src/ops/flash_attention_cpu.cc
create mode 100644 src/ops/flash_attention_gpu.cu diff --git a/.gitmodules b/.gitmodules index cbd99fc96..3e8584488 100644 --- a/.gitmodules +++ b/.gitmodules @@ -16,3 +16,6 @@ [submodule "third_party/ruy"] path = third_party/ruy url = https://github.com/google/ruy.git +[submodule "third_party/cutlass"] + path = third_party/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/CMakeLists.txt b/CMakeLists.txt index a32a45fe7..bd37cc510 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,9 @@ set(SOURCES src/env.cc src/filesystem.cc src/generator.cc + src/layers/attention_layer.cc src/layers/attention.cc + src/layers/flash_attention.cc src/layers/common.cc src/layers/decoder.cc src/layers/transformer.cc @@ -141,6 +143,8 @@ set(SOURCES src/ops/cos.cc src/ops/dequantize.cc src/ops/dequantize_cpu.cc + src/ops/flash_attention.cc + src/ops/flash_attention_cpu.cc src/ops/gather.cc src/ops/gather_cpu.cc src/ops/gelu.cc @@ -447,9 +451,8 @@ if (WITH_CUDA) else() list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=/MT$<$:d>") endif() - else() - list(APPEND CUDA_NVCC_FLAGS "-std=c++17") endif() + list(APPEND CUDA_NVCC_FLAGS "-std=c++17") if(OpenMP_CXX_FOUND) list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=${OpenMP_CXX_FLAGS}") endif() @@ -469,6 +472,11 @@ if (WITH_CUDA) list(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS}) set(CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER}) + # flags for flash attention + list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr") + list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda") + list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") + message(STATUS "NVCC host compiler: ${CUDA_HOST_COMPILER}") message(STATUS "NVCC compilation flags: ${CUDA_NVCC_FLAGS}") @@ -482,6 +490,12 @@ if (WITH_CUDA) cuda_include_directories(${THRUST_INCLUDE_DIRS}) list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${THRUST_INCLUDE_DIRS}) + set(CUTLASS_INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include + ) + cuda_include_directories(${CUTLASS_INCLUDE_DIRS}) + list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIRS}) + if(WITH_CUDNN) # Find cuDNN includes. 
find_path(CUDNN_INCLUDE_DIR NAMES cudnn.h HINTS ${CUDA_TOOLKIT_ROOT_DIR}/include) @@ -531,6 +545,7 @@ if (WITH_CUDA) src/ops/concat_split_slide_gpu.cu src/ops/conv1d_gpu.cu src/ops/dequantize_gpu.cu + src/ops/flash_attention_gpu.cu src/ops/gather_gpu.cu src/ops/gumbel_max_gpu.cu src/ops/layer_norm_gpu.cu @@ -544,6 +559,38 @@ if (WITH_CUDA) src/ops/topp_mask_gpu.cu src/ops/quantize_gpu.cu src/ops/nccl_ops_gpu.cu + src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu ) elseif(WITH_CUDNN) message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON") diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h index d2deb5e03..87b21f725 100644 --- a/include/ctranslate2/layers/attention.h +++ b/include/ctranslate2/layers/attention.h @@ -1,6 +1,6 @@ #pragma once -#include "ctranslate2/layers/common.h" +#include "ctranslate2/layers/attention_layer.h" #include "ctranslate2/padder.h" namespace ctranslate2 { @@ -13,7 +13,7 @@ namespace ctranslate2 { class RotaryEmbeddings; class Alibi; - class MultiHeadAttention : public Layer + class MultiHeadAttention : public AttentionLayer { public: MultiHeadAttention(const models::Model& model, @@ -25,7 +25,7 @@ namespace ctranslate2 { Alibi* alibi = nullptr); DataType output_type() const override; dim_t output_size() const override; - void operator()(const StorageView& queries, + virtual void operator()(const StorageView& queries, const StorageView& values, const StorageView* values_lengths, StorageView& output, @@ -36,95 +36,27 @@ namespace ctranslate2 { const Padder* values_padder = nullptr, bool return_normalized_attention = true, StorageView* position_bias = nullptr, - dim_t offset = 0) const; + dim_t offset = 
0) const override; - bool has_positional_embeddings() const { - return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi; + virtual bool has_positional_embeddings() const override { + return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi; } - - bool multi_query() const { - return _multi_query; - } - - static StorageView prepare_length_mask(const StorageView& lengths, - const dim_t num_heads, - const dim_t num_queries, - const bool mask_future = false, - const bool multi_query = false); - private: - const bool _tensor_parallel; - const dim_t _num_heads; - const bool _self_attention; - const bool _is_decoder; - const std::vector _linear; - const dim_t _d_model; - const dim_t _d_head; - const bool _pre_norm; - const std::unique_ptr _layer_norm; - const std::unique_ptr _rotary_embeddings; - Alibi* _alibi; + static void split_heads(StorageView& x, + dim_t num_heads, + const Padder* padder = nullptr, + dim_t beam_size = 1); + + static void combine_heads(StorageView& x, + dim_t num_heads, + const Padder* padder = nullptr, + dim_t beam_size = 1); const StorageView* _relative_attention_bias; const StorageView* _relative_position_keys; const StorageView* _relative_position_values; dim_t _maximum_relative_position; - const float _queries_scale; - const bool _multi_query; - const dim_t _num_heads_kv; const bool _merge_time_and_head_dims; const dim_t _cache_time_dim; - const dim_t _sliding_window; - }; - - enum class RotaryScalingType { - None = -1, - Linear, }; - - class RotaryEmbeddings { - public: - RotaryEmbeddings(const dim_t dim = 0, - const bool interleave = true, - const RotaryScalingType scaling_type = RotaryScalingType::None, - const float scaling_factor = 1, - const float base = 10000, - const dim_t num_initial_positions = 2048); - - void apply(StorageView& x, const dim_t offset = 0); - - private: - void initialize(const dim_t num_positions, - const dim_t dim, - const Device device, - const DataType dtype); - - const dim_t _dim; - const bool _interleave; - const RotaryScalingType _scaling_type; - const float _scaling_factor; - const float _base; - const dim_t _num_initial_positions; - const ops::Rotary _rotary_op; - - StorageView _sin; - StorageView _cos; - }; - - - class Alibi { - public: - Alibi(const bool use_positive_positions = false, const bool scale_alibi = false, const dim_t num_initial_positions = 2048); - - void apply(StorageView& x, const float scale = 1); - - private: - const bool _use_positive_positions; - const dim_t _num_initial_positions; - const bool _scale_alibi; - const ops::AlibiAdd _alibi_op; - - StorageView _alibi; - }; - } } diff --git a/include/ctranslate2/layers/attention_layer.h b/include/ctranslate2/layers/attention_layer.h new file mode 100644 index 000000000..daa9206c5 --- /dev/null +++ b/include/ctranslate2/layers/attention_layer.h @@ -0,0 +1,136 @@ +#pragma once + +#include "ctranslate2/layers/common.h" +#include "ctranslate2/padder.h" + +namespace ctranslate2 { + namespace layers { + StorageView make_relative_positions(dim_t queries_length, + dim_t keys_length, + dim_t max_position); + + class RotaryEmbeddings; + class Alibi; + + class AttentionLayer : public Layer + { + public: + AttentionLayer(const models::Model& model, + const std::string& scope, + dim_t num_heads, + bool self_attention, + bool pre_norm = true, + bool is_decoder = false, + Alibi* alibi = nullptr, + bool is_flash_attn = false); + virtual ~AttentionLayer() {}; + DataType output_type() const override; + dim_t output_size() 
const override; + virtual void operator()(const StorageView& queries, + const StorageView& values, + const StorageView* values_lengths, + StorageView& output, + StorageView* cached_keys = nullptr, + StorageView* cached_values = nullptr, + StorageView* attention = nullptr, + const Padder* queries_padder = nullptr, + const Padder* values_padder = nullptr, + bool return_normalized_attention = true, + StorageView* position_bias = nullptr, + dim_t offset = 0) const = 0; + + virtual bool has_positional_embeddings() const = 0; + + bool multi_query() const { + return _multi_query; + } + + static StorageView prepare_length_mask(const StorageView& lengths, + const dim_t num_heads, + const dim_t num_queries, + const bool mask_future = false, + const bool multi_query = false); + + protected: + const bool _tensor_parallel; + const dim_t _num_heads; + const bool _self_attention; + const bool _is_decoder; + const std::vector _linear; + const dim_t _d_model; + const dim_t _d_head; + const bool _pre_norm; + const std::unique_ptr _layer_norm; + const std::unique_ptr _rotary_embeddings; + Alibi* _alibi; + const float _queries_scale; + const bool _multi_query; + const dim_t _num_heads_kv; + const dim_t _sliding_window; + }; + + enum class RotaryScalingType { + None = -1, + Linear, + }; + + class RotaryEmbeddings { + public: + RotaryEmbeddings(const dim_t dim = 0, + const bool interleave = true, + const RotaryScalingType scaling_type = RotaryScalingType::None, + const float scaling_factor = 1, + const float base = 10000, + const dim_t num_initial_positions = 2048, + const bool transpose = true); + + void apply(StorageView& x, const dim_t offset = 0, bool apply = true); + + StorageView& get_cos() { + return _cos; + } + + StorageView& get_sin() { + return _sin; + } + + bool get_interleave() const { + return _interleave; + } + + private: + void initialize(const dim_t num_positions, + const dim_t dim, + const Device device, + const DataType dtype); + + const dim_t _dim; + const bool _interleave; + const RotaryScalingType _scaling_type; + const float _scaling_factor; + const float _base; + const dim_t _num_initial_positions; + const ops::Rotary _rotary_op; + const bool _transpose; + + StorageView _sin; + StorageView _cos; + }; + + + class Alibi { + public: + Alibi(const bool use_positive_positions = false, const bool scale_alibi = false, const dim_t num_initial_positions = 2048); + + void apply(StorageView& x, const float scale = 1); + + private: + const bool _use_positive_positions; + const dim_t _num_initial_positions; + const bool _scale_alibi; + const ops::AlibiAdd _alibi_op; + + StorageView _alibi; + }; + } +} diff --git a/include/ctranslate2/layers/flash_attention.h b/include/ctranslate2/layers/flash_attention.h new file mode 100644 index 000000000..315670f33 --- /dev/null +++ b/include/ctranslate2/layers/flash_attention.h @@ -0,0 +1,53 @@ +#pragma once + +#include "ctranslate2/layers/attention_layer.h" + +namespace ctranslate2 { + namespace layers { + + class RotaryEmbeddings; + class Alibi; + + class FlashMultiHeadAttention : public AttentionLayer + { + public: + FlashMultiHeadAttention(const models::Model& model, + const std::string& scope, + dim_t num_heads, + bool self_attention, + bool pre_norm = true, + bool is_decoder = false, + Alibi* alibi = nullptr); + void operator()(const StorageView& queries, + const StorageView& values, + const StorageView* values_lengths, + StorageView& output, + StorageView* cached_keys = nullptr, + StorageView* cached_values = nullptr, + StorageView* attention = nullptr, + 
const Padder* queries_padder = nullptr, + const Padder* values_padder = nullptr, + bool return_normalized_attention = true, + StorageView* position_bias = nullptr, + dim_t offset = 0) const override; + + virtual bool has_positional_embeddings() const override { + return _rotary_embeddings || _alibi; + } + + private: + static void split_heads(StorageView& x, + dim_t num_heads, + const Padder* padder = nullptr, + dim_t beam_size = 1); + + static void combine_heads(StorageView& x, + dim_t num_heads, + const Padder* padder = nullptr, + dim_t beam_size = 1); + + const dim_t _cache_time_dim; + static constexpr dim_t _offset_free_space{100}; + }; + } +} diff --git a/include/ctranslate2/layers/layers.h b/include/ctranslate2/layers/layers.h index 3ca357855..d1d6dd57f 100644 --- a/include/ctranslate2/layers/layers.h +++ b/include/ctranslate2/layers/layers.h @@ -1,5 +1,6 @@ #pragma once #include "attention.h" +#include "flash_attention.h" #include "common.h" #include "transformer.h" diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h index a7183a30d..95576ef0f 100644 --- a/include/ctranslate2/layers/transformer.h +++ b/include/ctranslate2/layers/transformer.h @@ -1,6 +1,7 @@ #pragma once #include "ctranslate2/layers/attention.h" +#include "ctranslate2/layers/flash_attention.h" #include "ctranslate2/layers/common.h" #include "ctranslate2/layers/decoder.h" #include "ctranslate2/layers/encoder.h" @@ -44,7 +45,8 @@ namespace ctranslate2 { const std::string& scope, const dim_t num_heads, const bool pre_norm = true, - const ops::ActivationType activation_type = ops::ActivationType::ReLU); + const ops::ActivationType activation_type = ops::ActivationType::ReLU, + bool use_flash_attention = false); void operator()(const StorageView& input, const StorageView* lengths, @@ -60,12 +62,12 @@ namespace ctranslate2 { return _ff.output_size(); } - const MultiHeadAttention& get_self_attention() const { - return _self_attention; + const AttentionLayer& get_self_attention() const { + return *_self_attention; } private: - const MultiHeadAttention _self_attention; + std::unique_ptr _self_attention; const FeedForwardNetwork _ff; }; @@ -77,6 +79,7 @@ namespace ctranslate2 { const dim_t num_heads, const bool pre_norm = true, const ops::ActivationType activation_type = ops::ActivationType::ReLU, + const bool use_flash_attention = true, Alibi* alibi = nullptr); void operator()(const StorageView& input, @@ -107,16 +110,16 @@ namespace ctranslate2 { return bool(_encoder_attention); } - const MultiHeadAttention& get_self_attention() const { - return _self_attention; + const AttentionLayer& get_self_attention() const { + return *_self_attention; } private: - const MultiHeadAttention _self_attention; + const std::unique_ptr _self_attention; const std::unique_ptr _shared_layer_norm; const std::unique_ptr _input_layer_norm; const std::unique_ptr _post_attention_layer_norm; - const std::unique_ptr _encoder_attention; + const std::unique_ptr _encoder_attention; const FeedForwardNetwork _ff; }; @@ -148,6 +151,7 @@ namespace ctranslate2 { const ComputeType _compute_type; const std::unique_ptr _layernorm_embedding; const std::unique_ptr _output_norm; + const bool _use_flash_attention; const std::vector> _layers; const std::unique_ptr _position_encoder; const bool _tensor_parallel; @@ -206,6 +210,7 @@ namespace ctranslate2 { const std::unique_ptr _project_in; const std::unique_ptr _project_out; const std::unique_ptr _alibi; + const bool _use_flash_attention; const std::vector> _layers; const 
std::unique_ptr _position_encoder; const bool _with_encoder_attention; diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 1bd7a4c14..babba1455 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -27,11 +27,13 @@ namespace ctranslate2 { Device device = Device::CPU, int device_index = 0, ComputeType compute_type = ComputeType::DEFAULT, + bool use_flash_attention = false, bool tensor_parallel = false); static std::shared_ptr load(ModelReader& model_reader, Device device = Device::CPU, int device_index = 0, ComputeType compute_type = ComputeType::DEFAULT, + bool use_flash_attention = false, bool tensor_parallel = false); virtual std::unique_ptr as_sequence_to_sequence() const; @@ -84,6 +86,10 @@ namespace ctranslate2 { return _tensor_parallel; } + bool use_flash_attention() const { + return _use_flash_attention; + } + virtual bool use_global_int16_scale() const { return true; } @@ -169,6 +175,7 @@ namespace ctranslate2 { ComputeType _effective_compute_type = ComputeType::DEFAULT; dim_t _preferred_size_multiple = 1; std::unordered_map> _variable_index; + bool _use_flash_attention = false; bool _tensor_parallel = false; }; @@ -198,6 +205,7 @@ namespace ctranslate2 { std::vector device_indices = {0}; size_t num_replicas_per_device = 1; ComputeType compute_type = ComputeType::DEFAULT; + bool use_flash_attention = false; bool tensor_parallel = false; }; diff --git a/include/ctranslate2/ops/flash-attention/alibi.h b/include/ctranslate2/ops/flash-attention/alibi.h new file mode 100644 index 000000000..80d297fc9 --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/alibi.h @@ -0,0 +1,74 @@ +#include + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Alibi { + + const float alibi_slope; + const int max_seqlen_k, max_seqlen_q; + + __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) + : alibi_slope(alibi_slope) + , max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) { + }; + + + template + __forceinline__ __device__ void apply_alibi(Tensor &tensor, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + } + } + } else { // Bias depends on both row_idx and col_idx + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = 
col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + } + } + } + } + +}; + +} // namespace flash diff --git a/include/ctranslate2/ops/flash-attention/block_info.h b/include/ctranslate2/ops/flash-attention/block_info.h new file mode 100644 index 000000000..3a23a1e1f --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/block_info.h @@ -0,0 +1,46 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/include/ctranslate2/ops/flash-attention/flash.h b/include/ctranslate2/ops/flash-attention/flash.h new file mode 100644 index 000000000..c35b0e1a4 --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/flash.h @@ -0,0 +1,178 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. 
+ index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + + int *__restrict__ blockmask; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Scale factor of 1 / (1 - p_dropout). + // Local window size + int window_size_left, window_size_right; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. 
+ void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h b/include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h new file mode 100644 index 000000000..4bff64f07 --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h @@ -0,0 +1,1207 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "mask.h" +#include "rotary.h" + +namespace flash { + + using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. 
+ // Otherwise we might read OOB elements from gK and gV. + if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +#pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. 
+ const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor tSgS = thr_mma.partition_C(gP); + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // + // 
PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. 
+ // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + cute::copy(rP, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(acc_s); + if (Return_softmax) { + cute::copy(rP, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.rp_dropout); + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. 
+ if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyOaccum, + typename Kernel_traits::GmemTiledCopyO + >; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +#pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // 
Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. 
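// The shift above can be read as the following host-side sketch (illustrative only, not
// part of the kernel; the names below are hypothetical and the real code uses index_t
// rather than int for the strides):
//
//   // Row r of the conceptual K = [K_cache ; K_new] for one (batch, kv-head).
//   template <typename T>
//   const T* conceptual_k_row(const T* k_cache, const T* k_new,
//                             int seqlen_k_cache, int row_stride, int r) {
//     // Shifting k_new back by seqlen_k_cache rows lets the same row index r walk both tensors.
//     const T* k_new_shifted = k_new - seqlen_k_cache * row_stride;
//     return r < seqlen_k_cache ? k_cache + r * row_stride
//                               : k_new_shifted + r * row_stride;  // == k_new + (r - seqlen_k_cache) * row_stride
//   }
//
// With seqlen_k_cache = 128 and 64 new rows, r in [128, 192) resolves to rows [0, 64) of
// k_new, matching the gK[:128] / gKnew[128:128 + 64] example in the comment above.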
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + + } + } + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + const int offset_diff = block_table_offset_next - block_table_offset_cur; + tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; + tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + } + } + } + // Need this before we can read in K again, so that we'll see the updated K values. 
+ __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
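// In register terms (a rough sketch of why this relabelling is free; the numbers are for
// the SM80 16x8x16 MMA this kernel targets): each thread holds 4 accumulator values per
// 16x8 C tile of P, while the A operand of the next 16x8x16 MMA needs 8 values per thread
// per 16x16 tile, so two adjacent C tiles along N are regrouped into one A tile along K --
// hence ((4, 2), MMA_M, MMA_N / 2). No data moves: the same registers are reinterpreted,
// which is why rP.data() is passed straight to make_tensor below.
//
//   // constexpr int kAccPerCTile = 4;                 // C fragment, 16x8 tile, per thread
//   // constexpr int kAPerKTile   = 8;                 // A fragment, 16x16 tile, per thread
//   // static_assert(2 * kAccPerCTile == kAPerKTile);  // two N tiles of P feed one K tile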
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + // if (cute::thread0()) { print(lse); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { __syncthreads(); } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
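// Same bookkeeping as in the non-split epilogue: the identity tensor taccOcO records, for
// every accumulator register, the (row, column) it covers in the output tile. Splitting the
// size-4 mode into (2, 2) and fixing the column half to 0 leaves one coordinate per row a
// thread owns, so only the threads whose first column is column 0 write the per-row LSE.
// For the SM80 16x8x16 C fragment that works out to (illustrative helper, not part of the
// kernel):
//
//   // void owned_rows(int lane, int rows[2]) { rows[0] = lane / 4; rows[1] = lane / 4 + 8; }
//   // first owned column: (lane % 4) * 2, which is 0 exactly when lane % 4 == 0.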
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + // __syncthreads(); + // if (cute::thread0()) { print(tOgOaccum); } + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void compute_attn(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void compute_attn_splitkv(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + template + inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. 
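// In scalar form, what this combine kernel computes (an illustrative reference, not the
// implementation below): each split s produced a partial output O_s and a partial
// log-sum-exp lse_s over its slice of the keys, and the exact full result is
// lse = log(sum_s exp(lse_s)) and O = sum_s exp(lse_s - lse) * O_s, with the usual
// max subtraction for numerical stability:
//
//   // One query row and one output element; requires <cmath>, <algorithm>, <vector>.
//   // float combine(const std::vector<float>& lse_s, const std::vector<float>& o_s) {
//   //   float m = -INFINITY;
//   //   for (float l : lse_s) m = std::max(m, l);                 // max of the split LSEs
//   //   float sum = 0.f;
//   //   for (float l : lse_s) sum += std::exp(l - m);
//   //   const float lse = std::log(sum) + m;                      // written to gLSE below
//   //   float o = 0.f;
//   //   for (size_t s = 0; s < lse_s.size(); ++s) o += std::exp(lse_s[s] - lse) * o_s[s];
//   //   return o;                                                 // accumulated into tOrO below
//   // }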
+ // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { sLSE[row][col] = lse; } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? 
INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } + // Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM + ); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print_tensor(tOrO); } + + Tensor rO = flash::convert_type(tOrO); + // Write to gO +#pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + + head_idx * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + 
col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } + } + +} // namespace flash diff --git a/include/ctranslate2/ops/flash-attention/flash_fwd_launch_template.h b/include/ctranslate2/ops/flash-attention/flash_fwd_launch_template.h new file mode 100644 index 000000000..6f14be1bd --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/flash_fwd_launch_template.h @@ -0,0 +1,334 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" +#include "ctranslate2/devices.h" +#include "cuda/utils.h" + +template +#if __CUDA_ARCH__ >= 800 +__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + flash::compute_attn(params); +} +#else +__global__ void flash_fwd_kernel(const Flash_fwd_params params) { +} +#endif + + +template +#if __CUDA_ARCH__ >= 800 +__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) { + flash::compute_attn_splitkv(params); +} +#else +__global__ void flash_fwd_splitkv_kernel(const Flash_fwd_params params) { +} +#endif + +template +#if __CUDA_ARCH__ >= 800 +__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) { + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +} +#else +__global__ void flash_fwd_splitkv_combine_kernel(const Flash_fwd_params params) { +} +#endif + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + CUDA_CHECK(cudaGetLastError()); + }); + }); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + CUDA_CHECK(cudaGetLastError()) + }); + }); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + CUDA_CHECK(cudaGetLastError()); + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int kBlockM = 64; // Fixed for all head dimensions + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + // Also for headdim 160 with block size 64 x 128 after the rotary addition. + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 32; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 96; + int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); + auto dprops = ctranslate2::cuda::get_device_properties(device_id); + bool is_sm8x = dprops.major == 8 && dprops.minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); + auto dprops = ctranslate2::cuda::get_device_properties(device_id); + bool is_sm8x = dprops.major == 8 && dprops.minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
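// The 48 KB figure quoted above is just the tile footprint from kernel_traits with 2-byte
// elements: smem = kBlockM * Headdim (Q) + 2 * kBlockN * Headdim (K and V). A quick check
// for the non-causal sm86/sm89 choice (names local to this sketch):
//
//   // constexpr int kBytesPerElt = 2, kHd = 128, kM = 128, kN = 32;
//   // constexpr int kSmemBytes = kBytesPerElt * kHd * (kM + 2 * kN);  // 2 * 128 * 192 = 49152 B
//   // static_assert(kSmemBytes == 48 * 1024);
//
// and 2 x 48 KB fits in the roughly 100 KB of shared memory per SM on sm86/sm89, which is
// where the "2 CTAs per SM" above comes from.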
+ if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 160; + int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); + auto dprops = ctranslate2::cuda::get_device_properties(device_id); + bool is_sm8x = dprops.major == 8 && dprops.minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 192; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 224; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. 
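// The smem guard above is the same footprint formula evaluated for Headdim = 224:
//
//   // constexpr int kBytes128x64 = 2 * 224 * (128 + 2 * 64);  // Q(128x224) + K,V(64x224), 2-byte elements
//   // static_assert(kBytes128x64 == 112 * 1024);              // 114688 B = 112 KB
//
// which only fits where the opt-in per-block shared memory limit is large enough (e.g.
// A100/H100); hence the cudaDeviceGetAttribute query above, with the 64 x 64 tile as the
// fallback on smaller parts.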
+ // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + CUDA_CHECK(status_); + } + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); +} diff --git a/include/ctranslate2/ops/flash-attention/kernel_traits.h b/include/ctranslate2/ops/flash-attention/kernel_traits.h new file mode 100644 index 000000000..00e772d0b --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/kernel_traits.h @@ -0,0 +1,160 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; +#else + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 
64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + Tile, _16, _16>>; + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. 
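// In PTX terms the choice below is between cp.async.cg.shared.global (bypass L1, cache in
// L2 only, 16-byte transfers) and cp.async.ca.shared.global (cache in L1 as well). Roughly
// what one 16-byte element of GmemTiledCopyQKV lowers to -- illustrative inline asm, not
// the CuTe-generated code:
//
//   // __device__ void cp_async_cg_16B(void* smem_dst, const void* gmem_src) {
//   //   const unsigned dst = static_cast<unsigned>(__cvta_generic_to_shared(smem_dst));
//   //   asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" :: "r"(dst), "l"(gmem_src));
//   // }
//
// The copies are then grouped and waited on with cp.async.commit_group / cp.async.wait_group,
// which is what the cp_async_fence() / cp_async_wait<N>() calls in the kernels wrap.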
+ using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/ctranslate2/ops/flash-attention/mask.h b/include/ctranslate2/ops/flash-attention/mask.h new file mode 100644 index 000000000..3d9b42985 --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/mask.h @@ -0,0 +1,213 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +namespace flash { + +using namespace cute; + +template +__forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +__forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const 
int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +__forceinline__ __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +__forceinline__ __device__ void apply_mask_causal_w_idx( + Tensor &tensor, Tensor const &idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) +{ + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); + #pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +struct Mask { + + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope=0.f) + : max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , alibi_slope(!Has_alibi ? 
0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor &tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } + } + } + } + } + } else { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; + +}; + +} // namespace flash diff --git a/include/ctranslate2/ops/flash-attention/philox.cuh b/include/ctranslate2/ops/flash-attention/philox.cuh new file mode 100644 index 000000000..cd7e4d2fa --- /dev/null 
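Mask::apply_mask above fuses three per-element concerns: the sliding-window (local) limits, causal masking (a local window with an unbounded left side and a right size of zero), and the optional ALiBi bias. A host-side sketch of the per-element arithmetic, under the same conventions but with illustrative names that are not part of the patch:

#include <algorithm>
#include <cmath>
#include <cstdlib>

// Host reference for one score at absolute position (row_idx, col_idx),
// following the limits computed in flash::Mask::apply_mask.
// window_left == -1 means "unbounded"; causal is the window (-1, 0).
float mask_and_bias(float score, int row_idx, int col_idx,
                    int seqlen_q, int seqlen_k,
                    int window_left, int window_right,
                    bool causal, float alibi_slope) {
  const int shift = row_idx + seqlen_k - seqlen_q;  // aligns the last query with the last key
  const int limit_right = std::min(seqlen_k, shift + 1 + (causal ? 0 : window_right));
  const int limit_left = (causal || window_left < 0)
                             ? 0
                             : std::max(0, shift - window_left);
  if (col_idx >= limit_right || col_idx < limit_left)
    return -INFINITY;
  if (alibi_slope != 0.f) {
    // Causal case: adding slope * col_idx differs from -slope * |shift - col_idx|
    // only by a per-row constant, which the softmax cancels.
    score += causal ? alibi_slope * col_idx
                    : -alibi_slope * std::abs(shift - col_idx);
  }
  return score;
}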
+++ b/include/ctranslate2/ops/flash-attention/philox.cuh @@ -0,0 +1,51 @@ +// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h +#pragma once +// Philox CUDA. + +namespace flash { + +struct ull2 { + unsigned long long x; + unsigned long long y; +}; + +__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; +} + +__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { + constexpr unsigned long kPhiloxSA = 0xD2511F53; + constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; +} + +__forceinline__ __device__ uint4 philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + constexpr unsigned long kPhilox10A = 0x9E3779B9; + constexpr unsigned long kPhilox10B = 0xBB67AE85; + uint2 key = reinterpret_cast(seed); + uint4 counter; + ull2 *tmp = reinterpret_cast(&counter); + tmp->x = offset; + tmp->y = subsequence; + #pragma unroll + for (int i = 0; i < 6; i++) { + counter = philox_single_round(counter, key); + key.x += (kPhilox10A); + key.y += (kPhilox10B); + } + uint4 output = philox_single_round(counter, key); + return output; +} + +} // namespace flash diff --git a/include/ctranslate2/ops/flash-attention/rotary.h b/include/ctranslate2/ops/flash-attention/rotary.h new file mode 100644 index 000000000..dc2825be7 --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/rotary.h @@ -0,0 +1,152 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
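philox.cuh implements the counter-based Philox-4x32 generator used for dropout; the device path relies on PTX mul.wide.u32, but the arithmetic is easy to cross-check on the host. A sketch of the same rounds, with illustrative host types that are not part of the patch:

#include <cstdint>

// Host reference for the Philox rounds in philox.cuh: each round mixes the
// 128-bit counter with two 32-bit keys via 32x32 -> 64-bit multiplies, and
// the key is bumped by the standard Philox constants between rounds. Six
// rounds in the loop plus a final one gives Philox-4x32-7, as in flash::philox.
struct U2 { uint32_t x, y; };
struct U4 { uint32_t x, y, z, w; };

static U2 mulhilo32(uint32_t a, uint32_t b) {
  const uint64_t p = static_cast<uint64_t>(a) * b;
  return { static_cast<uint32_t>(p), static_cast<uint32_t>(p >> 32) };  // {lo, hi}
}

static U4 philox_round(U4 ctr, U2 key) {
  const U2 r0 = mulhilo32(0xD2511F53u, ctr.x);
  const U2 r1 = mulhilo32(0xCD9E8D57u, ctr.z);
  return { r1.y ^ ctr.y ^ key.x, r1.x, r0.y ^ ctr.w ^ key.y, r0.x };
}

U4 philox_4x32_7(uint64_t seed, uint64_t subsequence, uint64_t offset) {
  U2 key = { static_cast<uint32_t>(seed), static_cast<uint32_t>(seed >> 32) };
  U4 ctr = { static_cast<uint32_t>(offset),      static_cast<uint32_t>(offset >> 32),
             static_cast<uint32_t>(subsequence), static_cast<uint32_t>(subsequence >> 32) };
  for (int i = 0; i < 6; ++i) {
    ctr = philox_round(ctr, key);
    key.x += 0x9E3779B9u;
    key.y += 0xBB67AE85u;
  }
  return philox_round(ctr, key);
}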
+ ******************************************************************************/ + +#pragma once + +#include + +#include "utils.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); 
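copy_rotary_interleaved applies rotary position embeddings while the tile is staged from global memory: elements (2i, 2i+1) are treated as one complex value and rotated by (cos_i, sin_i). A host-side sketch of the per-pair math, with illustrative names that are not part of the patch:

#include <cstddef>
#include <vector>

// Host reference for the interleaved rotary rotation in
// flash::copy_rotary_interleaved: x[2i] and x[2i+1] form one complex value.
void rotary_interleaved(std::vector<float>& x,           // [rotary_dim], rotary_dim even
                        const std::vector<float>& cos_,  // [rotary_dim / 2]
                        const std::vector<float>& sin_) {
  for (std::size_t i = 0; i < x.size() / 2; ++i) {
    const float re = x[2 * i];
    const float im = x[2 * i + 1];
    x[2 * i]     = re * cos_[i] - im * sin_[i];
    x[2 * i + 1] = re * sin_[i] + im * cos_[i];
  }
}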
// MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/include/ctranslate2/ops/flash-attention/softmax.h b/include/ctranslate2/ops/flash-attention/softmax.h new file mode 100644 index 000000000..0af500c56 --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/softmax.h @@ -0,0 +1,185 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
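copy_rotary_contiguous handles the other rotary layout: the partner of element i sits half the rotary dimension away rather than next door, and the sign on sin depends on which half i belongs to. A host-side sketch, with illustrative names that are not part of the patch:

#include <cstddef>
#include <vector>

// Host reference for the contiguous (half-split) rotary rotation in
// flash::copy_rotary_contiguous: element i is paired with i +/- rotary_dim/2.
void rotary_contiguous(std::vector<float>& x,           // [rotary_dim], rotary_dim even
                       const std::vector<float>& cos_,  // [rotary_dim / 2]
                       const std::vector<float>& sin_) {
  const std::size_t half = x.size() / 2;
  const std::vector<float> x_in = x;  // rotate from a copy, as the kernel does via rS_other
  for (std::size_t i = 0; i < x.size(); ++i) {
    const bool left = i < half;
    const float other = x_in[left ? i + half : i - half];
    const float c = cos_[left ? i : i - half];
    const float s = sin_[left ? i : i - half];
    x[i] = x_in[i] * c + other * (left ? -s : s);
  }
}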
+ ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "philox.cuh" +#include "utils.h" + +#ifndef M_LOG2E +#define M_LOG2E 1.4426950408889634074 +#endif + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). 
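scale_apply_exp2 is the key numerical trick in this header: instead of exp(scale * (x - max)) it evaluates exp2 of the same quantity with the scale pre-multiplied by log2(e), so the scale and subtraction fuse into a single FFMA feeding exp2f. A host-side sketch for one row, including the -inf guard for fully masked rows, with illustrative names that are not part of the patch:

#include <cmath>
#include <vector>

// Host reference for flash::scale_apply_exp2 on one row of scores.
// softmax_scale_log2 is the attention scale (e.g. 1/sqrt(head_dim))
// pre-multiplied by log2(e), so exp2(x * s - max * s) == exp(scale * (x - max)).
void scale_apply_exp2_row(std::vector<float>& row, float row_max, float softmax_scale_log2) {
  // If the whole row was masked, row_max is -inf; substituting 0 avoids
  // (-inf) - (-inf) = NaN and every element becomes exp2(-inf) = 0.
  const float max_scaled = (row_max == -INFINITY) ? 0.f : row_max * softmax_scale_log2;
  for (float& x : row)
    x = std::exp2(x * softmax_scale_log2 - max_scaled);
}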
+ // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? 
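Softmax::softmax_rescale_o is the online-softmax step that lets flash attention stream over key blocks in a single pass: when a new block raises the running row maximum, the running sum and the already-accumulated output are rescaled by exp2((m_prev - m_new) * scale) before the new block's probabilities are folded in. A host-side sketch of one such update for a single query row, assuming row_max starts at -inf and row_sum / acc_out start at zero; names are illustrative and not part of the patch:

#include <algorithm>
#include <cmath>
#include <vector>

// Host reference for one online-softmax update, mirroring
// flash::Softmax::softmax_rescale_o (the P*V accumulation that the kernel
// performs with a GEMM is only indicated by a comment here).
void online_softmax_step(std::vector<float>& block_scores,  // raw scores for this key block
                         std::vector<float>& acc_out,       // accumulated, unnormalized output row
                         float& row_max, float& row_sum,    // running statistics
                         float scale_log2) {                // softmax scale * log2(e)
  const float prev_max = row_max;
  for (float s : block_scores) row_max = std::max(row_max, s);
  // Guard for fully masked rows, like the kernel's Check_inf path.
  const float safe_max = (row_max == -INFINITY) ? 0.f : row_max;
  const float correction = (prev_max == -INFINITY)
                               ? 0.f
                               : std::exp2((prev_max - safe_max) * scale_log2);
  row_sum *= correction;
  for (float& o : acc_out) o *= correction;
  // Exponentiate this block and extend the running sum; in the kernel the
  // resulting probabilities then multiply V into acc_out.
  const float max_scaled = (row_max == -INFINITY) ? 0.f : row_max * scale_log2;
  for (float& s : block_scores) {
    s = std::exp2(s * scale_log2 - max_scaled);
    row_sum += s;
  }
}

Deferring the division by row_sum to normalize_softmax_lse is what keeps the pass single-sweep: only scalar statistics per row have to be carried between key blocks.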
-INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = inv_sum; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash diff --git a/include/ctranslate2/ops/flash-attention/static_switch.h b/include/ctranslate2/ops/flash-attention/static_switch.h new file mode 100644 index 000000000..7b38de2d0 --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/static_switch.h @@ -0,0 +1,108 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once +#include + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT + #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI + #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + #define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else + #define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define LOCAL_SWITCH BOOL_SWITCH +#endif + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) 
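static_switch.h converts runtime flags into compile-time constants: both branches are instantiated and the right one is picked at run time, so the kernels themselves can be fully specialized, and the FLASHATTENTION_DISABLE_* macros can prune instantiations to cut build time and binary size. A condensed, self-contained sketch of the pattern with a simplified macro and illustrative function names (not part of the patch):

#include <cstdio>

// Simplified stand-in for BOOL_SWITCH: expose a runtime bool to the lambda
// body as a constexpr so the body can use it as a template argument.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)  \
  [&] {                                            \
    if (COND) {                                    \
      constexpr bool CONST_NAME = true;            \
      return __VA_ARGS__();                        \
    } else {                                       \
      constexpr bool CONST_NAME = false;           \
      return __VA_ARGS__();                        \
    }                                              \
  }()

template <bool Is_causal>
void run_kernel_sketch() {
  // In the real launcher this would select a fully specialized flash kernel.
  std::printf("causal=%d\n", int(Is_causal));
}

void launch(bool causal_flag) {
  BOOL_SWITCH_SKETCH(causal_flag, Is_causal, [&] { run_kernel_sketch<Is_causal>(); });
}

int main() {
  launch(true);
  launch(false);
}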
\ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/include/ctranslate2/ops/flash-attention/utils.h b/include/ctranslate2/ops/flash-attention/utils.h new file mode 100644 index 000000000..094053386 --- /dev/null +++ b/include/ctranslate2/ops/flash-attention/utils.h @@ -0,0 +1,395 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ uint32_t relu2(const uint32_t x); + +template<> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); + +template<> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template<> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
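HEADDIM_SWITCH rounds the runtime head dimension up to one of the compiled buckets, so a model with head_dim 80, for example, runs the 96-wide kernels with the unused tail columns predicated off. A host-side sketch of just the bucketing, with an illustrative name that is not part of the patch:

#include <initializer_list>
#include <stdexcept>

// Host-side view of the HEADDIM_SWITCH dispatch: the runtime head dimension
// is rounded up to the next compiled bucket; kernels for that bucket handle
// the unused tail columns through the Is_even_K predication paths.
int head_dim_bucket(int head_dim) {
  for (int bucket : {32, 64, 96, 128, 160, 192, 224, 256})
    if (head_dim <= bucket)
      return bucket;
  throw std::invalid_argument("flash attention supports head dimensions up to 256");
}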
x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
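Allreduce performs a warp-level reduction with __shfl_xor_sync: at every step each lane combines its value with the lane whose id differs in one bit, halving the offset until every lane holds the full reduction. A host-side simulation of that butterfly, shown for the 4-lane quads used by quad_allreduce_; names are illustrative and not part of the patch:

#include <array>

// Host simulation of the XOR butterfly behind flash::Allreduce and
// __shfl_xor_sync: after log2(N) rounds every "lane" holds the reduction
// of all N inputs.
template <int N, typename T, typename Op>
std::array<T, N> butterfly_allreduce(std::array<T, N> lanes, Op op) {
  static_assert((N & (N - 1)) == 0, "lane count must be a power of two");
  for (int offset = N / 2; offset >= 1; offset /= 2) {
    std::array<T, N> exchanged = lanes;
    for (int lane = 0; lane < N; ++lane)
      exchanged[lane] = op(lanes[lane], lanes[lane ^ offset]);  // partner = lane XOR offset
    lanes = exchanged;
  }
  return lanes;
}

// Example: a max-reduction over one quad, as in reduce_max.
// auto m = butterfly_allreduce<4, float>({1.f, 7.f, 3.f, 5.f},
//                                        [](float a, float b) { return a > b ? a : b; });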
Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +template +__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void relu_(Tensor &tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); + #pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +__forceinline__ __device__ auto convert_type_relu(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor 
tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); + #pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. 
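flash::copy above is the predicated tile copy that makes ragged sequence lengths and odd head dimensions safe: rows past max_MN are skipped or cleared, and columns past the valid head dimension are skipped or zero-filled depending on Clear_OOB_K. A host sketch of the same predication over a plain 2-D tile, with illustrative names that are not part of the patch:

#include <vector>

// Host reference for the bounds handling in flash::copy: only rows below
// max_rows and columns below valid_cols are real; out-of-bounds columns are
// zero-filled when clear_oob_cols is set (the Clear_OOB_K path) and fully
// out-of-bounds rows are zero-filled when clear_oob_rows is set.
void predicated_copy(const std::vector<float>& src, std::vector<float>& dst,
                     int rows, int cols, int max_rows, int valid_cols,
                     bool clear_oob_cols, bool clear_oob_rows) {
  for (int m = 0; m < rows; ++m) {
    if (m < max_rows) {
      for (int k = 0; k < cols; ++k) {
        if (k < valid_cols)
          dst[m * cols + k] = src[m * cols + k];
        else if (clear_oob_cols)
          dst[m * cols + k] = 0.f;
      }
    } else if (clear_oob_rows) {
      for (int k = 0; k < cols; ++k)
        dst[m * cols + k] = 0.f;
    }
  }
}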
+ // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_w_min_idx(Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, + const int max_MN=0, const int min_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/include/ctranslate2/ops/flash_attention.h b/include/ctranslate2/ops/flash_attention.h new file mode 100644 index 000000000..a05133690 --- /dev/null +++ b/include/ctranslate2/ops/flash_attention.h @@ -0,0 +1,44 @@ +#pragma once + +#include "op.h" + +namespace ctranslate2 { + namespace ops { + class FlashAttention : public Op { + public: + FlashAttention(float queries_scale, dim_t sliding_window); + + void operator()(StorageView& queries, + StorageView& keys, + StorageView& values, + StorageView& output, + StorageView* cached_keys, + StorageView* cached_values, + StorageView* attention, + bool return_normalized_attention, + StorageView* rotary_cos, + StorageView* rotary_sin, + const bool rotary_interleave, + StorageView* alibi, + dim_t offset) const; + + private: + const float _queries_scale; + const dim_t _sliding_window; + template + void compute(StorageView& queries, + StorageView& keys, + StorageView& values, + StorageView& output, + StorageView* cached_keys, + StorageView* cached_values, + StorageView* attention, + bool return_normalized_attention, + StorageView* rotary_cos, + StorageView* 
rotary_sin, + const bool rotary_interleave, + StorageView* alibi, + dim_t offset) const; + }; + } +} diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h index f03d0211a..ceca49450 100644 --- a/include/ctranslate2/ops/ops.h +++ b/include/ctranslate2/ops/ops.h @@ -38,3 +38,4 @@ #include "alibi_add.h" #include "slide.h" #include "nccl_ops.h" +#include "flash_attention.h" diff --git a/include/ctranslate2/ops/rotary.h b/include/ctranslate2/ops/rotary.h index c0a4cf091..81413b45b 100644 --- a/include/ctranslate2/ops/rotary.h +++ b/include/ctranslate2/ops/rotary.h @@ -12,7 +12,8 @@ namespace ctranslate2 { void operator()(const StorageView& input, const StorageView& sin, const StorageView& cos, - StorageView& output) const; + StorageView& output, + bool is_transpose=true) const; private: const dim_t _ndims; @@ -22,7 +23,8 @@ namespace ctranslate2 { void compute(const StorageView& input, const StorageView& sin, const StorageView& cos, - StorageView& output) const; + StorageView& output, + bool is_transpose) const; }; } diff --git a/include/ctranslate2/utils.h b/include/ctranslate2/utils.h index 23c58cb82..f1c122d7f 100644 --- a/include/ctranslate2/utils.h +++ b/include/ctranslate2/utils.h @@ -95,5 +95,9 @@ namespace ctranslate2 { #define THROW_INVALID_ARGUMENT(MESSAGE) THROW_EXCEPTION(std::invalid_argument, MESSAGE) #define SAFE_DIVIDE(x, y) ((y != 0 && (x % y == 0)) ? (x / y) : (throw std::runtime_error("Division has a remainder," \ "Model can't be ran with the tensor parallel mode in " + std::to_string(y) + " nodes"))) - +#define ERROR_CHECK(ans, message) \ + { \ + if (!ans) \ + THROW_RUNTIME_ERROR(message); \ + } } diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc index 9a50923ac..cd1ebec91 100644 --- a/python/cpp/encoder.cc +++ b/python/cpp/encoder.cc @@ -71,7 +71,7 @@ namespace ctranslate2 { >>> encoder.forward_batch([["▁Hello", "▁world", "!"]]) )pbdoc") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -80,6 +80,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("flash_attention")=false, py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( @@ -97,6 +98,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + flash_attention: run model with flash attention 2 for self-attention layer tensor_parallel: run model with tensor parallel mode files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. 
If this is set, diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index 93b1a229a..4ade4af5b 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -128,7 +128,7 @@ namespace ctranslate2 { >>> generator.generate_batch([[""]], max_length=50, sampling_topk=20) )pbdoc") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -137,6 +137,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("flash_attention")=false, py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( @@ -154,6 +155,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + flash_attention: run model with flash attention 2 for self-attention layer tensor_parallel: run model with tensor parallel mode. files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, diff --git a/python/cpp/replica_pool.h b/python/cpp/replica_pool.h index d71bf6b96..8ea52817f 100644 --- a/python/cpp/replica_pool.h +++ b/python/cpp/replica_pool.h @@ -44,6 +44,7 @@ namespace ctranslate2 { size_t inter_threads, size_t intra_threads, long max_queued_batches, + bool flash_attention, bool tensor_parallel, py::object files) : _model_loader(create_model_reader(model_path, files)) @@ -54,6 +55,7 @@ namespace ctranslate2 { _model_loader.device_indices = std::visit(DeviceIndexResolver(), device_index); _model_loader.compute_type = std::visit(ComputeTypeResolver(device), compute_type); _model_loader.num_replicas_per_device = inter_threads; + _model_loader.use_flash_attention = flash_attention; _model_loader.tensor_parallel = tensor_parallel; _pool_config.num_threads_per_replica = intra_threads; diff --git a/python/cpp/translator.cc b/python/cpp/translator.cc index 8e4a8a4be..15e529919 100644 --- a/python/cpp/translator.cc +++ b/python/cpp/translator.cc @@ -33,6 +33,7 @@ namespace ctranslate2 { size_t inter_threads, size_t intra_threads, long max_queued_batches, + bool flash_attention, bool tensor_parallel, py::object files) : ReplicaPoolHelper(model_path, @@ -42,6 +43,7 @@ namespace ctranslate2 { inter_threads, intra_threads, max_queued_batches, + flash_attention, tensor_parallel, files) , _device(_model_loader.device) @@ -380,7 +382,7 @@ namespace ctranslate2 { >>> translator.translate_batch([["▁Hello", "▁world", "!"]]) )pbdoc") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -389,6 +391,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("flash_attention")=false, py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( @@ -406,6 +409,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. 
+ flash_attention: run model with flash attention 2 for self-attention layer tensor_parallel: run model with tensor parallel mode files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, diff --git a/python/cpp/wav2vec2.cc b/python/cpp/wav2vec2.cc index 343caa158..5b882cbbd 100644 --- a/python/cpp/wav2vec2.cc +++ b/python/cpp/wav2vec2.cc @@ -27,7 +27,7 @@ namespace ctranslate2 { https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec )pbdoc") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -36,6 +36,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("flash_attention")=false, py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( @@ -53,6 +54,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + flash_attention: run model with flash attention 2 for self-attention layer tensor_parallel: run model with tensor parallel mode files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, diff --git a/python/cpp/whisper.cc b/python/cpp/whisper.cc index 47be8ece7..88a1e4c17 100644 --- a/python/cpp/whisper.cc +++ b/python/cpp/whisper.cc @@ -163,7 +163,7 @@ namespace ctranslate2 { .def_property_readonly("num_languages", &WhisperWrapper::num_languages, "Returns the number of languages supported.") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -172,6 +172,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("flash_attention")=false, py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( @@ -189,6 +190,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + flash_attention: run model with flash attention 2 for self-attention layer tensor_parallel: run model with tensor parallel mode files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, diff --git a/python/tools/prepare_build_environment_windows.sh b/python/tools/prepare_build_environment_windows.sh index b5fe03ecc..1a4d76eeb 100755 --- a/python/tools/prepare_build_environment_windows.sh +++ b/python/tools/prepare_build_environment_windows.sh @@ -26,14 +26,14 @@ curl -L -O https://github.com/oneapi-src/oneDNN/archive/refs/tags/v${ONEDNN_VERS tar xf *.tar.gz && rm *.tar.gz cd oneDNN-* cmake -DCMAKE_BUILD_TYPE=Release -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_EXAMPLES=OFF -DONEDNN_BUILD_TESTS=OFF -DONEDNN_ENABLE_WORKLOAD=INFERENCE -DONEDNN_ENABLE_PRIMITIVE="CONVOLUTION;REORDER" -DONEDNN_BUILD_GRAPH=OFF . -cmake --build . 
--config Release --target install --parallel 2 +cmake --build . --config Release --target install --parallel 6 cd .. rm -r oneDNN-* mkdir build cd build cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=$CTRANSLATE2_ROOT -DCMAKE_PREFIX_PATH="C:/Program Files (x86)/Intel/oneAPI/compiler/latest/windows/compiler/lib/intel64_win;C:/Program Files (x86)/oneDNN" -DBUILD_CLI=OFF -DWITH_DNNL=ON -DWITH_CUDA=ON -DWITH_CUDNN=ON -DCUDA_TOOLKIT_ROOT_DIR="$CUDA_ROOT" -DCUDA_DYNAMIC_LOADING=ON -DCUDA_NVCC_FLAGS="-Xfatbin=-compress-all" -DCUDA_ARCH_LIST="Common" .. -cmake --build . --config Release --target install --parallel 2 --verbose +cmake --build . --config Release --target install --parallel 6 --verbose cd .. rm -r build diff --git a/src/cuda/utils.h b/src/cuda/utils.h index 8c1c134fe..2f8c4f5ab 100644 --- a/src/cuda/utils.h +++ b/src/cuda/utils.h @@ -62,6 +62,12 @@ namespace ctranslate2 { + std::string(cudnnGetErrorString(status))); \ } +#define TENSOR_CHECK(ans, message) \ + { \ + if (!ans) \ + THROW_RUNTIME_ERROR(message); \ + } + const char* cublasGetStatusName(cublasStatus_t status); cudaStream_t get_cuda_stream(); diff --git a/src/layers/attention.cc b/src/layers/attention.cc index cf6074b2a..f340c44f9 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -106,47 +106,6 @@ namespace ctranslate2 { return values_t; } - static StorageView build_alibi(dim_t num_heads, - dim_t key_max_length, - bool use_positive_positions, - const float scale) { - const float closest_power_of_2_f = std::pow(2.f, std::floor(std::log2f(num_heads))); - const dim_t closest_power_of_2 = closest_power_of_2_f; - - const float base = std::pow(2.f, -std::pow(2.f, -(std::log2f(closest_power_of_2_f) - 3.f))); - - std::vector slopes; - slopes.reserve(closest_power_of_2); - for (dim_t power = 1; power <= closest_power_of_2; ++power) - slopes.emplace_back(std::pow(base, float(power))); - - if (closest_power_of_2 != num_heads) { - const float extra_base = ( - std::pow(2.f, -std::pow(2.f, -(std::log2f(2 * closest_power_of_2_f) - 3.f)))); - const dim_t num_remaining_heads = std::min( - closest_power_of_2, num_heads - closest_power_of_2); - - for (dim_t power = 1; power <= 2 * num_remaining_heads; power += 2) - slopes.emplace_back(std::pow(extra_base, float(power))); - } - - std::vector positions(key_max_length); - std::iota(positions.begin(), - positions.end(), - use_positive_positions ? 
0 : -key_max_length + 1); - - StorageView alibi({1, num_heads, 1, key_max_length}); - - for (dim_t h = 0; h < num_heads; ++h) { - primitives::mul(slopes[h] * scale, - positions.data(), - alibi.index({0, h, 0, 0}), - key_max_length); - } - - return alibi; - } - static void matmul_with_relative_representations(const ops::MatMul& matmul_op, const StorageView& a, const StorageView& b, @@ -281,30 +240,7 @@ namespace ctranslate2 { save_attention(*attention, std::move(attn), beam_size); } - static void split_heads(StorageView& x, - dim_t num_heads, - const Padder* padder = nullptr, - dim_t beam_size = 1) { - if (padder) - padder->add_padding(x); - - if (beam_size > 1) - x.reshape({x.dim(0) / beam_size, beam_size, x.dim(2)}); - - // x has shape [batch_size, time, depth] - const dim_t batch_size = x.dim(0); - const dim_t time = x.dim(1); - const dim_t head_dim = x.dim(2) / num_heads; - if (time == 1) { - x.reshape({batch_size, num_heads, 1, head_dim}); - } else { - x.reshape({batch_size, time, num_heads, head_dim}); - StorageView y(x.device(), x.dtype()); - transpose_op(x, y); - x = std::move(y); - } - } static void replicate_heads(StorageView& x, dim_t repeats) { x.expand_dims(2); @@ -312,66 +248,6 @@ namespace ctranslate2 { x.reshape({x.dim(0), x.dim(1) * x.dim(2), x.dim(3), x.dim(4)}); } - static void combine_heads(StorageView& x, - dim_t num_heads, - const Padder* padder = nullptr, - dim_t beam_size = 1) { - // x has shape [batch_size, num_heads, time, head_dim] - const dim_t batch_size = x.dim(0); - const dim_t time = x.dim(2); - const dim_t depth = x.dim(3) * num_heads; - - if (time > 1) { - StorageView y(x.device(), x.dtype()); - transpose_op(x, y); - x = std::move(y); - } - - x.reshape({batch_size, time, depth}); - - if (beam_size > 1) - x.reshape({batch_size * beam_size, 1, depth}); - - if (padder) - padder->remove_padding(x); - } - - static std::vector make_linear_layers(const models::Model& model, - const std::string& scope, - bool self_attention) { - const dim_t num_linear_layers = self_attention ? 2 : 3; - std::vector layers; - layers.reserve(num_linear_layers); - for (dim_t i = 0; i < num_linear_layers; ++i) - if (i == (num_linear_layers - 1)) { - layers.emplace_back(model, scope + "/linear_" + std::to_string(i), nullptr, true); - } else - layers.emplace_back(model, scope + "/linear_" + std::to_string(i)); - return layers; - } - - static std::unique_ptr make_rotary_embeddings(const models::Model& model, - const std::string& scope) { - const dim_t rotary_dim = model.get_attribute_with_default(scope + "/rotary_dim", -1); - if (rotary_dim < 0) - return nullptr; - - const bool interleave = model.get_flag_with_default(scope + "/rotary_interleave", true); - const float base = model.get_attribute_with_default(scope + "/rotary_base", 10000.f); - - const auto scaling_type = model.get_enum_value( - scope + "/rotary_scaling_type", -1); - const auto scaling_factor = model.get_attribute_with_default( - scope + "/rotary_scaling_factor", 1.f); - - return std::make_unique(rotary_dim, - interleave, - scaling_type, - scaling_factor, - base); - } - - MultiHeadAttention::MultiHeadAttention(const models::Model& model, const std::string& scope, dim_t num_heads, @@ -379,35 +255,15 @@ namespace ctranslate2 { bool pre_norm, bool is_decoder, Alibi* alibi) - : _tensor_parallel(model.tensor_parallel()) - , _num_heads(_tensor_parallel ? 
SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()) : num_heads) - , _self_attention(self_attention) - , _is_decoder(is_decoder) - , _linear(make_linear_layers(model, scope, self_attention)) - , _d_model(_tensor_parallel ? SAFE_DIVIDE(_linear.back().output_size(), ScopedMPISetter::getNRanks()) : _linear.back().output_size()) - , _d_head(model.get_attribute_with_default(scope + "/head_dim", _d_model / _num_heads)) - , _pre_norm(pre_norm) - , _layer_norm(build_optional_layer(model, scope + "/layer_norm")) - , _rotary_embeddings(make_rotary_embeddings(model, scope)) - , _alibi(alibi) + : AttentionLayer(model, scope, num_heads, self_attention, pre_norm, is_decoder, alibi, false) , _relative_attention_bias(model.get_variable_if_exists(scope + "/relative_attention_bias")) , _relative_position_keys(model.get_variable_if_exists(scope + "/relative_position_keys")) , _relative_position_values(model.get_variable_if_exists(scope + "/relative_position_values")) - , _queries_scale(model.get_attribute_with_default( - scope + "/queries_scale", - 1.f / std::sqrt(static_cast(_d_head)))) - , _multi_query(model.get_flag_with_default(scope + "/multi_query", false)) - , _num_heads_kv(_multi_query - ? 1 - : (_tensor_parallel ? model.get_attribute_with_default(scope + "/num_heads_kv", - _num_heads * ScopedMPISetter::getNRanks()) / ScopedMPISetter::getNRanks() - : model.get_attribute_with_default(scope + "/num_heads_kv", _num_heads))) , _merge_time_and_head_dims(_multi_query && !_relative_attention_bias && !_relative_position_keys && !_relative_position_values) - , _cache_time_dim(_merge_time_and_head_dims ? 1 : 2) - , _sliding_window(model.get_attribute_with_default(scope + "/sliding_window", 0)) + ,_cache_time_dim(_merge_time_and_head_dims ? 1 : 2) { if (_relative_position_keys) _maximum_relative_position = (_relative_position_keys->dim(0) - 1) / 2; @@ -613,137 +469,53 @@ namespace ctranslate2 { } } - StorageView MultiHeadAttention::prepare_length_mask(const StorageView& lengths, - const dim_t num_heads, - const dim_t num_queries, - const bool mask_future, - const bool multi_query) { - const Device device = lengths.device(); - const dim_t batch_size = lengths.size(); - StorageView mask(lengths.dtype(), device); - - if (multi_query) - mask.resize({batch_size, num_queries, num_heads}); - else - mask.resize({batch_size, num_heads, num_queries}); - - DEVICE_DISPATCH(device, (primitives::prepare_length_mask(lengths.data(), - batch_size, - num_heads, - num_queries, - mask_future, - multi_query, - mask.data()))); - return mask; - } - + void MultiHeadAttention::split_heads(StorageView& x, + dim_t num_heads, + const Padder* padder, + dim_t beam_size) { + if (padder) + padder->add_padding(x); - RotaryEmbeddings::RotaryEmbeddings(const dim_t dim, - const bool interleave, - const RotaryScalingType scaling_type, - const float scaling_factor, - const float base, - const dim_t num_initial_positions) - : _dim(dim) - , _interleave(interleave) - , _scaling_type(scaling_type) - , _scaling_factor(scaling_factor) - , _base(base) - , _num_initial_positions(num_initial_positions) - , _rotary_op(dim, interleave) - { - } + if (beam_size > 1) + x.reshape({x.dim(0) / beam_size, beam_size, x.dim(2)}); - void RotaryEmbeddings::apply(StorageView& x, const dim_t offset) { - const Device device = x.device(); - const DataType dtype = x.dtype(); - const dim_t max_time = x.dim(-2); - const dim_t dim = _dim == 0 ? 
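The split_heads member above keeps the behaviour of the former free function: [batch, time, depth] becomes [batch, num_heads, time, head_dim], with a reshape-only fast path when time == 1 (single-token decoding) because the two layouts are then identical. A host sketch of the index mapping for the general case, with an illustrative name that is not part of the patch:

#include <vector>

// Host reference for the layout change in split_heads when time > 1:
// [batch, time, heads * head_dim] -> [batch, heads, time, head_dim].
// For time == 1 the mapping is the identity, which is why the member
// function only reshapes in that case.
std::vector<float> split_heads_ref(const std::vector<float>& x,
                                   int batch, int time, int heads, int head_dim) {
  std::vector<float> y(x.size());
  for (int b = 0; b < batch; ++b)
    for (int t = 0; t < time; ++t)
      for (int h = 0; h < heads; ++h)
        for (int d = 0; d < head_dim; ++d)
          y[((b * heads + h) * time + t) * head_dim + d] =
              x[(b * time + t) * heads * head_dim + h * head_dim + d];
  return y;
}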
x.dim(-1) : _dim; + // x has shape [batch_size, time, depth] + const dim_t batch_size = x.dim(0); + const dim_t time = x.dim(1); + const dim_t head_dim = x.dim(2) / num_heads; - if (!_sin || offset + max_time > _sin.dim(0)) { - const dim_t cur_num_positions = _sin ? _sin.dim(0) : 0; - const dim_t new_num_positions = std::max(offset + max_time, cur_num_positions + _num_initial_positions); - initialize(new_num_positions, dim, device, dtype); + if (time == 1) { + x.reshape({batch_size, num_heads, 1, head_dim}); + } else { + x.reshape({batch_size, time, num_heads, head_dim}); + StorageView y(x.device(), x.dtype()); + transpose_op(x, y); + x = std::move(y); } - - StorageView sin(dtype, device); - StorageView cos(dtype, device); - TYPE_DISPATCH(dtype, - { - sin.view(_sin.index({offset, 0}), {max_time, dim}); - cos.view(_cos.index({offset, 0}), {max_time, dim}); - }); - - StorageView y(dtype, device); - _rotary_op(x, sin, cos, y); - x = std::move(y); - } - - void RotaryEmbeddings::initialize(const dim_t num_positions, - const dim_t dim, - const Device device, - const DataType dtype) { - StorageView inv_freq({1, dim / 2}); - for (dim_t i = 0; i < inv_freq.size(); ++i) - inv_freq.at(i) = 1.f / std::pow(_base, float(i * 2) / float(dim)); - if (inv_freq.device() != device) - inv_freq = inv_freq.to(device); - - StorageView t({num_positions, 1}); - for (dim_t i = 0; i < t.size(); ++i) - t.at(i) = _scaling_type == RotaryScalingType::None ? i : float(i) / _scaling_factor; - if (t.device() != device) - t = t.to(device); - - StorageView freqs(device); - ops::MatMul()(t, inv_freq, freqs); - - if (_interleave) - freqs.expand_dims(-1); - - StorageView emb(device); - ops::Concat(-1)({&freqs, &freqs}, emb); - - if (_interleave) - emb.reshape({num_positions, dim}); - - StorageView sin(device); - ops::Sin()(emb, sin); - if (sin.dtype() == dtype) - _sin = std::move(sin); - else - _sin = sin.to(dtype); - - StorageView cos(device); - ops::Cos()(emb, cos); - if (cos.dtype() == dtype) - _cos = std::move(cos); - else - _cos = cos.to(dtype); } + void MultiHeadAttention::combine_heads(StorageView& x, + dim_t num_heads, + const Padder* padder, + dim_t beam_size) { + // x has shape [batch_size, num_heads, time, head_dim] + const dim_t batch_size = x.dim(0); + const dim_t time = x.dim(2); + const dim_t depth = x.dim(3) * num_heads; - Alibi::Alibi(const bool use_positive_positions, const bool scale_alibi, const dim_t num_initial_positions) - : _use_positive_positions(use_positive_positions) - , _num_initial_positions(num_initial_positions) - , _scale_alibi(scale_alibi) - , _alibi_op(use_positive_positions) - { - } + if (time > 1) { + StorageView y(x.device(), x.dtype()); + transpose_op(x, y); + x = std::move(y); + } - void Alibi::apply(StorageView& x, const float scale) { - const dim_t cur_length = _alibi ? _alibi.dim(-1) : 0; - const dim_t key_length = x.dim(-1); + x.reshape({batch_size, time, depth}); - if (key_length > cur_length) { - const dim_t num_heads = x.dim(1); - const dim_t new_length = cur_length + _num_initial_positions; - _alibi = build_alibi(num_heads, new_length, _use_positive_positions, _scale_alibi ? 
scale : 1); - _alibi.move_to(x.device(), x.dtype()); - } + if (beam_size > 1) + x.reshape({batch_size * beam_size, 1, depth}); - _alibi_op(x, _alibi, x); + if (padder) + padder->remove_padding(x); } - } } diff --git a/src/layers/attention_layer.cc b/src/layers/attention_layer.cc new file mode 100644 index 000000000..18b4fa16b --- /dev/null +++ b/src/layers/attention_layer.cc @@ -0,0 +1,271 @@ +#include "ctranslate2/layers/attention.h" +#include "ctranslate2/ops/split.h" + +#include +#include +#include + +#include "dispatch.h" +#include "cpu/parallel.h" +#include + +namespace ctranslate2 { + namespace layers { + static StorageView build_alibi(dim_t num_heads, + dim_t key_max_length, + bool use_positive_positions, + const float scale) { + const float closest_power_of_2_f = std::pow(2.f, std::floor(std::log2f(num_heads))); + const dim_t closest_power_of_2 = closest_power_of_2_f; + + const float base = std::pow(2.f, -std::pow(2.f, -(std::log2f(closest_power_of_2_f) - 3.f))); + + std::vector slopes; + slopes.reserve(closest_power_of_2); + for (dim_t power = 1; power <= closest_power_of_2; ++power) + slopes.emplace_back(std::pow(base, float(power))); + + if (closest_power_of_2 != num_heads) { + const float extra_base = ( + std::pow(2.f, -std::pow(2.f, -(std::log2f(2 * closest_power_of_2_f) - 3.f)))); + const dim_t num_remaining_heads = std::min( + closest_power_of_2, num_heads - closest_power_of_2); + + for (dim_t power = 1; power <= 2 * num_remaining_heads; power += 2) + slopes.emplace_back(std::pow(extra_base, float(power))); + } + + std::vector positions(key_max_length); + std::iota(positions.begin(), + positions.end(), + use_positive_positions ? 0 : -key_max_length + 1); + + StorageView alibi({1, num_heads, 1, key_max_length}); + + for (dim_t h = 0; h < num_heads; ++h) { + primitives::mul(slopes[h] * scale, + positions.data(), + alibi.index({0, h, 0, 0}), + key_max_length); + } + + return alibi; + } + + static std::vector make_linear_layers(const models::Model& model, + const std::string& scope, + bool self_attention) { + const dim_t num_linear_layers = self_attention ? 2 : 3; + std::vector layers; + layers.reserve(num_linear_layers); + for (dim_t i = 0; i < num_linear_layers; ++i) + if (i == (num_linear_layers - 1)) { + layers.emplace_back(model, scope + "/linear_" + std::to_string(i), nullptr, true); + } else + layers.emplace_back(model, scope + "/linear_" + std::to_string(i)); + return layers; + } + + static std::unique_ptr make_rotary_embeddings(const models::Model& model, + const std::string& scope, + bool transpose) { + const dim_t rotary_dim = model.get_attribute_with_default(scope + "/rotary_dim", -1); + if (rotary_dim < 0) + return nullptr; + + const bool interleave = model.get_flag_with_default(scope + "/rotary_interleave", true); + const float base = model.get_attribute_with_default(scope + "/rotary_base", 10000.f); + + const auto scaling_type = model.get_enum_value( + scope + "/rotary_scaling_type", -1); + const auto scaling_factor = model.get_attribute_with_default( + scope + "/rotary_scaling_factor", 1.f); + + return std::make_unique(rotary_dim, + interleave, + scaling_type, + scaling_factor, + base, + /*num_initial_positions*/2048, + transpose); + } + + + AttentionLayer::AttentionLayer(const models::Model& model, + const std::string& scope, + dim_t num_heads, + bool self_attention, + bool pre_norm, + bool is_decoder, + Alibi* alibi, + bool is_flash_attn) + : _tensor_parallel(model.tensor_parallel()) + , _num_heads(_tensor_parallel ? 
SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()) : num_heads) + , _self_attention(self_attention) + , _is_decoder(is_decoder) + , _linear(make_linear_layers(model, scope, self_attention)) + , _d_model(_tensor_parallel ? SAFE_DIVIDE(_linear.back().output_size(), ScopedMPISetter::getNRanks()) : _linear.back().output_size()) + , _d_head(model.get_attribute_with_default(scope + "/head_dim", _d_model / _num_heads)) + , _pre_norm(pre_norm) + , _layer_norm(build_optional_layer(model, scope + "/layer_norm")) + , _rotary_embeddings(make_rotary_embeddings(model, scope, !is_flash_attn)) + , _alibi(alibi) + , _queries_scale(model.get_attribute_with_default( + scope + "/queries_scale", + 1.f / std::sqrt(static_cast(_d_head)))) + , _multi_query(model.get_flag_with_default(scope + "/multi_query", false)) + , _num_heads_kv(_multi_query + ? 1 + : (_tensor_parallel ? model.get_attribute_with_default(scope + "/num_heads_kv", + _num_heads * ScopedMPISetter::getNRanks()) / ScopedMPISetter::getNRanks() + : model.get_attribute_with_default(scope + "/num_heads_kv", _num_heads))) + , _sliding_window(model.get_attribute_with_default(scope + "/sliding_window", 0)) + { + } + + DataType AttentionLayer::output_type() const { + return _linear.back().output_type(); + } + + dim_t AttentionLayer::output_size() const { + return _d_model; + } + + StorageView AttentionLayer::prepare_length_mask(const StorageView& lengths, + const dim_t num_heads, + const dim_t num_queries, + const bool mask_future, + const bool multi_query) { + const Device device = lengths.device(); + const dim_t batch_size = lengths.size(); + StorageView mask(lengths.dtype(), device); + + if (multi_query) + mask.resize({batch_size, num_queries, num_heads}); + else + mask.resize({batch_size, num_heads, num_queries}); + + DEVICE_DISPATCH(device, (primitives::prepare_length_mask(lengths.data(), + batch_size, + num_heads, + num_queries, + mask_future, + multi_query, + mask.data()))); + return mask; + } + + + RotaryEmbeddings::RotaryEmbeddings(const dim_t dim, + const bool interleave, + const RotaryScalingType scaling_type, + const float scaling_factor, + const float base, + const dim_t num_initial_positions, + const bool transpose) + : _dim(dim) + , _interleave(interleave) + , _scaling_type(scaling_type) + , _scaling_factor(scaling_factor) + , _base(base) + , _num_initial_positions(num_initial_positions) + , _rotary_op(dim, interleave) + , _transpose(transpose) + { + } + + void RotaryEmbeddings::apply(StorageView& x, const dim_t offset, bool apply) { + const Device device = x.device(); + const DataType dtype = x.dtype(); + const dim_t max_time = _transpose ? x.dim(-2) : x.dim(-3); + const dim_t dim = _dim == 0 ? x.dim(-1) : _dim; + + if (!_sin || offset + max_time > _sin.dim(0)) { + const dim_t cur_num_positions = _sin ? 
_sin.dim(0) : 0; + const dim_t new_num_positions = std::max(offset + max_time, cur_num_positions + _num_initial_positions); + initialize(new_num_positions, dim, device, dtype); + } + if (!apply) + return; + + StorageView sin(dtype, device); + StorageView cos(dtype, device); + TYPE_DISPATCH(dtype, + { + sin.view(_sin.index({offset, 0}), {max_time, dim}); + cos.view(_cos.index({offset, 0}), {max_time, dim}); + }); + + StorageView y(dtype, device); + _rotary_op(x, sin, cos, y, _transpose); + x = std::move(y); + } + + void RotaryEmbeddings::initialize(const dim_t num_positions, + const dim_t dim, + const Device device, + const DataType dtype) { + StorageView inv_freq({1, dim / 2}); + for (dim_t i = 0; i < inv_freq.size(); ++i) + inv_freq.at(i) = 1.f / std::pow(_base, float(i * 2) / float(dim)); + if (inv_freq.device() != device) + inv_freq = inv_freq.to(device); + + StorageView t({num_positions, 1}); + for (dim_t i = 0; i < t.size(); ++i) + t.at(i) = _scaling_type == RotaryScalingType::None ? i : float(i) / _scaling_factor; + if (t.device() != device) + t = t.to(device); + + StorageView freqs(device); + ops::MatMul()(t, inv_freq, freqs); + + if (_interleave) + freqs.expand_dims(-1); + + StorageView emb(device); + ops::Concat(-1)({&freqs, &freqs}, emb); + + if (_interleave) + emb.reshape({num_positions, dim}); + + StorageView sin(device); + ops::Sin()(emb, sin); + if (sin.dtype() == dtype) + _sin = std::move(sin); + else + _sin = sin.to(dtype); + + StorageView cos(device); + ops::Cos()(emb, cos); + if (cos.dtype() == dtype) + _cos = std::move(cos); + else + _cos = cos.to(dtype); + } + + + Alibi::Alibi(const bool use_positive_positions, const bool scale_alibi, const dim_t num_initial_positions) + : _use_positive_positions(use_positive_positions) + , _num_initial_positions(num_initial_positions) + , _scale_alibi(scale_alibi) + , _alibi_op(use_positive_positions) + { + } + + void Alibi::apply(StorageView& x, const float scale) { + const dim_t cur_length = _alibi ? _alibi.dim(-1) : 0; + const dim_t key_length = x.dim(-1); + + if (key_length > cur_length) { + const dim_t num_heads = x.dim(1); + const dim_t new_length = cur_length + _num_initial_positions; + _alibi = build_alibi(num_heads, new_length, _use_positive_positions, _scale_alibi ? 
scale : 1); + _alibi.move_to(x.device(), x.dtype()); + } + + _alibi_op(x, _alibi, x); + } + + } +} diff --git a/src/layers/flash_attention.cc b/src/layers/flash_attention.cc new file mode 100644 index 000000000..d676bb661 --- /dev/null +++ b/src/layers/flash_attention.cc @@ -0,0 +1,177 @@ +#include "ctranslate2/layers/flash_attention.h" + +namespace ctranslate2 { + namespace layers { + FlashMultiHeadAttention::FlashMultiHeadAttention(const models::Model& model, + const std::string& scope, + dim_t num_heads, + bool self_attention, + bool pre_norm, + bool is_decoder, + Alibi* alibi) + : AttentionLayer(model, scope, num_heads, self_attention, pre_norm, is_decoder, alibi, true) + , _cache_time_dim(1) + { + ERROR_CHECK((self_attention), "FlashAttention only supports the self-attention"); + } + + void FlashMultiHeadAttention::operator()(const StorageView& queries, + const StorageView& values, + const StorageView* values_lengths, + StorageView& output, + StorageView* cached_keys, + StorageView* cached_values, + StorageView* attention, + const Padder* queries_padder, + const Padder* values_padder, + bool return_normalized_attention, + StorageView* position_bias, + dim_t offset) const { + const Device device = queries.device(); + const DataType dtype = queries.dtype(); + + StorageView fused_proj(dtype, device); + StorageView queries_proj(dtype, device); + StorageView keys_proj(dtype, device); + StorageView values_proj(dtype, device); + + const StorageView* q = &queries; + if (_layer_norm && _pre_norm) { + (*_layer_norm)(queries, queries_proj); + q = &queries_proj; + } + + _linear[0](*q, fused_proj); + + dim_t beam_size = 1; + + bool prefilling = (_sliding_window > 0 && values_lengths); + + if (_num_heads_kv < _num_heads) { + if (queries_padder) + queries_padder->add_padding(fused_proj); + + const ops::Split split_op(2, {_d_model, _num_heads_kv * _d_head, _num_heads_kv * _d_head}); + split_op(fused_proj, queries_proj, keys_proj, values_proj); + + split_heads(queries_proj, _num_heads); + split_heads(keys_proj, _num_heads_kv); + split_heads(values_proj, _num_heads_kv); + } else { + split_heads(fused_proj, 3 * _num_heads, queries_padder); + ops::Split(2)(fused_proj, queries_proj, keys_proj, values_proj); + } + + if (_rotary_embeddings) { + _rotary_embeddings->apply(queries_proj, offset, offset == 0); + _rotary_embeddings->apply(keys_proj, offset, offset == 0); + } + + if (cached_keys != nullptr) { + if (cached_keys->empty()) { + *cached_keys = std::move(keys_proj); + *cached_values = std::move(values_proj); + } else if (cached_keys->dim(_cache_time_dim) <= offset) { + const ops::Concat concat_op(_cache_time_dim); + auto shape = cached_keys->shape(); + shape[_cache_time_dim] = _offset_free_space; + StorageView empty_storage(std::move(shape), dtype, device); + StorageView& tmp = fused_proj; // Reuse storage. 
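+          // Grow the cache in place: the existing keys/values are concatenated with an
+          // empty block of _offset_free_space positions along the cache time dimension.
+          // During sliding-window generation (offset >= _sliding_window - 1) the oldest
+          // cached position is then slid off below, so the window advances by one step.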
+ tmp = std::move(*cached_keys); + concat_op({&tmp, &empty_storage}, *cached_keys); + tmp = std::move(*cached_values); + concat_op({&tmp, &empty_storage}, *cached_values); + + if (!prefilling && _sliding_window > 0 && (offset / (_sliding_window - 1)) >= 1) { + // only for generation + const ops::Slide slide_op(_cache_time_dim, 1, cached_keys->shape()[_cache_time_dim] - 1); + slide_op(*cached_keys, tmp); + *cached_keys = std::move(tmp); + slide_op(*cached_values, tmp); + *cached_values = std::move(tmp); + } + } + } + + if (cached_keys && offset == 0) { + keys_proj.shallow_copy(*cached_keys); + values_proj.shallow_copy(*cached_values); + } + + StorageView* rotary_cos = nullptr; + StorageView* rotary_sin = nullptr; + bool rotary_interleaved = false; + if (_rotary_embeddings && offset > 0) { + rotary_cos = &(_rotary_embeddings->get_cos()); + rotary_sin = &(_rotary_embeddings->get_sin()); + rotary_interleaved = _rotary_embeddings->get_interleave(); + } + + // init output + StorageView context(dtype, device); + ops::FlashAttention fl_attn_ops(_queries_scale, _sliding_window); + fl_attn_ops(queries_proj, keys_proj, values_proj, context, cached_keys, cached_values, attention, + return_normalized_attention, rotary_cos, rotary_sin, rotary_interleaved, nullptr/*alibli*/, offset); + + if (prefilling && cached_keys && cached_keys->shape()[_cache_time_dim] > _sliding_window) { + // set only last sliding_window tokens to cached_keys and cached_values after computing attention + const ops::Slide slide_op(_cache_time_dim, cached_keys->shape()[_cache_time_dim] - _sliding_window, _sliding_window); + StorageView tmp(dtype, device); + slide_op(*cached_keys, tmp); + *cached_keys = std::move(tmp); + slide_op(*cached_values, tmp); + *cached_values = std::move(tmp); + } + combine_heads(context, _num_heads, queries_padder, beam_size); + + _linear.back()(context, output); + if (_tensor_parallel) { + StorageView tmp(output.shape(), output.dtype(), output.device()); + ops::ReduceAll ops_reduce_all(ops::ReduceAll::RED_OP::SUM); + ops_reduce_all(output, tmp); + output = std::move(tmp); + } + if (_layer_norm) { + ops::Add()(queries, output, output); + + if (!_pre_norm) + (*_layer_norm)(output, output); + } + } + void FlashMultiHeadAttention::split_heads(StorageView& x, + dim_t num_heads, + const Padder* padder, + dim_t beam_size) { + if (padder) + padder->add_padding(x); + + if (beam_size > 1) + x.reshape({x.dim(0) / beam_size, beam_size, x.dim(2)}); + + // x has shape [batch_size, time, depth] + const dim_t batch_size = x.dim(0); + const dim_t time = x.dim(1); + const dim_t head_dim = x.dim(2) / num_heads; + + x.reshape({batch_size, time, num_heads, head_dim}); + } + + void FlashMultiHeadAttention::combine_heads(StorageView& x, + dim_t num_heads, + const Padder* padder, + dim_t beam_size) { + // x has shape [batch_size, num_heads, time, head_dim] + const dim_t batch_size = x.dim(0); + const dim_t time = x.dim(1); + const dim_t depth = x.dim(3) * num_heads; + + x.reshape({batch_size, time, depth}); + + if (beam_size > 1) + x.reshape({batch_size * beam_size, 1, depth}); + + if (padder) + padder->remove_padding(x); + } + } +} \ No newline at end of file diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index 97b5669c1..291101eae 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -59,12 +59,17 @@ namespace ctranslate2 { const std::string& scope, const dim_t num_heads, const bool pre_norm, - const ops::ActivationType activation_type) - : _self_attention(model, + const 
ops::ActivationType activation_type, + const bool use_flash_attention) + : _self_attention(!use_flash_attention ? std::unique_ptr(new MultiHeadAttention(model, scope + "/self_attention", num_heads, /*self_attention=*/true, - pre_norm) + pre_norm)) : std::unique_ptr(new FlashMultiHeadAttention(model, + scope + "/self_attention", + num_heads, + /*self_attention=*/true, + pre_norm))) , _ff(model, scope + "/ffn", pre_norm, activation_type) { } @@ -75,17 +80,18 @@ namespace ctranslate2 { StorageView* position_bias) const { PROFILE("TransformerEncoderLayer"); StorageView context(input.dtype(), input.device()); - _self_attention(input, - input, - lengths, - context, - nullptr, - nullptr, - nullptr, - padder, - padder, - true, - position_bias); + if (_self_attention) + (*_self_attention)(input, + input, + lengths, + context, + nullptr, + nullptr, + nullptr, + padder, + padder, + true, + position_bias); _ff(context, output); } @@ -95,14 +101,21 @@ namespace ctranslate2 { const dim_t num_heads, const bool pre_norm, const ops::ActivationType activation_type, + const bool use_flash_attention, Alibi* alibi) - : _self_attention(model, + : _self_attention(!use_flash_attention ? std::unique_ptr(new MultiHeadAttention(model, + scope + "/self_attention", + num_heads, + /*self_attention=*/true, + pre_norm, + /*is_decoder=*/true, + alibi)) : std::unique_ptr(new FlashMultiHeadAttention(model, scope + "/self_attention", num_heads, /*self_attention=*/true, pre_norm, /*is_decoder=*/true, - alibi) + alibi))) , _shared_layer_norm(build_optional_layer(model, scope + "/shared_layer_norm")) , _input_layer_norm(build_optional_layer(model, scope + "/input_layer_norm")) , _post_attention_layer_norm(build_optional_layer( @@ -148,7 +161,8 @@ namespace ctranslate2 { (*_input_layer_norm)(input, hidden); StorageView attn(dtype, device); - _self_attention(hidden, + if (_self_attention) + (*_self_attention)(hidden, hidden, input_length, attn, @@ -171,8 +185,8 @@ namespace ctranslate2 { return; } - - _self_attention(input, + if (_self_attention) + (*_self_attention)(input, input, input_length, output, @@ -197,7 +211,8 @@ namespace ctranslate2 { input_padder, memory_padder, return_normalized_attention); - } else { + } + else { context = std::move(output); } @@ -249,6 +264,7 @@ namespace ctranslate2 { , _compute_type(model.effective_compute_type()) , _layernorm_embedding(build_optional_layer(model, scope + "/layernorm_embedding")) , _output_norm(build_optional_layer(model, scope + "/layer_norm")) + , _use_flash_attention(model.use_flash_attention()) , _layers(build_layers_list( model, scope + "/layer", @@ -335,12 +351,14 @@ namespace ctranslate2 { , _project_in(build_optional_layer(model, scope + "/project_in")) , _project_out(build_optional_layer(model, scope + "/project_out")) , _alibi(make_alibi(model, scope)) + , _use_flash_attention(model.use_flash_attention()) , _layers(build_layers_list( model, scope + "/layer", _num_heads, model.get_flag_with_default(scope + "/pre_norm", true), model.get_enum_value(scope + "/activation"), + _use_flash_attention, _alibi.get())) , _position_encoder(_layers.front()->get_self_attention().has_positional_embeddings() ? 
nullptr @@ -569,7 +587,7 @@ namespace ctranslate2 { while (true) { dim_t prompt_size = layer_in.dim(1); - if (_sliding_window == 0 || prompt_size <= _sliding_window) { + if (_sliding_window == 0 || prompt_size <= _sliding_window || _use_flash_attention) { layer_ins.push_back(std::move(layer_in)); break; } diff --git a/src/models/model.cc b/src/models/model.cc index 97bf3d1b5..7e18ff223 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -190,6 +190,8 @@ namespace ctranslate2 { DataType weight_dtype = DataType::FLOAT32; DataType float_dtype = DataType::FLOAT32; std::tie(weight_dtype, float_dtype) = compute_type_to_data_type(_effective_compute_type); + if (_use_flash_attention && (float_dtype != DataType::FLOAT16 && float_dtype != DataType::BFLOAT16)) + throw std::runtime_error("FlashAttention only support fp16 and bf16 data type"); const auto variable_index = _variable_index; for (auto& variable_pair : variable_index) { @@ -533,15 +535,17 @@ namespace ctranslate2 { Device device, int device_index, ComputeType compute_type, + bool use_flash_attention, bool tensor_parallel) { ModelFileReader model_reader(path); - return load(model_reader, device, device_index, compute_type, tensor_parallel); + return load(model_reader, device, device_index, compute_type, use_flash_attention, tensor_parallel); } std::shared_ptr Model::load(ModelReader& model_reader, Device device, int device_index, ComputeType compute_type, + bool use_flash_attention, bool tensor_parallel) { { // Log the system configuration the first time a model is loaded. @@ -585,6 +589,7 @@ namespace ctranslate2 { auto model = create_model(spec); model->_binary_version = binary_version; model->_spec_revision = spec_revision; + model->_use_flash_attention = use_flash_attention; model->_tensor_parallel = tensor_parallel; check_version(spec_revision, model->current_spec_revision(), "revision"); @@ -793,23 +798,36 @@ namespace ctranslate2 { throw std::invalid_argument("Cannot use multiple GPUs with different Compute Capabilities " "for the same model"); if (tensor_parallel && device != Device::CUDA) { - throw std::invalid_argument("Tensor Parallel mode can run only on cuda"); + throw std::invalid_argument("Tensor Parallel mode can run only on cuda"); } -#endif - - std::vector> models; if (tensor_parallel && (device_indices.size() > 1)) { spdlog::warn("Running model in mode tensor parallel does not support" " running independently a model in each device"); } + bool is_sm8x = false; + bool is_sm90 = false; + if (device == Device::CUDA) { + int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); + auto dprops = ctranslate2::cuda::get_device_properties(device_id); + is_sm8x = dprops.major == 8 && dprops.minor >= 0; + is_sm90 = dprops.major == 9 && dprops.minor == 0; + } + if (use_flash_attention && (device != Device::CUDA || (!is_sm8x && !is_sm90))) { + throw std::invalid_argument("FlashAttention only supports Ampere GPUs or newer."); + } +#endif + + std::vector> models; + models.reserve(device_indices.size() * num_replicas_per_device); for (const size_t device_index : device_indices) { std::shared_ptr model; if (models.empty()) - model = Model::load(*model_reader, device, device_index, compute_type, tensor_parallel); + model = Model::load(*model_reader, device, device_index, compute_type, + use_flash_attention, tensor_parallel); else model = models.back()->copy_to(device, device_index); diff --git a/src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu new file 
mode 100644 index 000000000..c9888029f --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..9f7a03952 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..d326c5f8d --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..887c32a37 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..f59ad95ed --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..ffc3026eb --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..d63afa5df --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..d46bfe21d --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..51a03e9fb --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..4573f7dbf --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..f8993fa43 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..f6e707ce4 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..cc209767f --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..9c9bd4195 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..e7eccc85e --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..930cf3685 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} diff --git a/src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..247660ede --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..84ecbbf88 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..82cdd1d01 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..b2644e0d0 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..6af24a602 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..430ded176 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..6a975e41e --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..a1053fb58 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..5427b6897 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..bc067fa68 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..bf33ab5c9 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..04127ed7a --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..22e62d74f --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..6d26e6bb0 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..dd59cab96 --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu b/src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..2f0202a7e --- /dev/null +++ b/src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "ctranslate2/ops/flash-attention/flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/src/ops/flash_attention.cc b/src/ops/flash_attention.cc new file mode 100644 index 000000000..eb6422bb0 --- /dev/null +++ b/src/ops/flash_attention.cc @@ -0,0 +1,31 @@ +#include "ctranslate2/ops/flash_attention.h" + +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + FlashAttention::FlashAttention(float queries_scale, dim_t sliding_window) + : _queries_scale(queries_scale) + ,_sliding_window(sliding_window) + { + } + + void FlashAttention::operator()(StorageView& queries, + StorageView& keys, + StorageView& values, + StorageView& output, + StorageView* cached_keys, + StorageView* cached_values, + StorageView* attention, + bool return_normalized_attention, + StorageView* rotary_cos, + StorageView* rotary_sin, + const bool rotary_interleave, + StorageView* alibi, + dim_t offset) const { + DEVICE_DISPATCH(queries.device(), compute(queries, keys, values, output, cached_keys, cached_values, + attention, return_normalized_attention, + rotary_cos, rotary_sin, rotary_interleave, alibi, offset)); + } + } +} diff --git a/src/ops/flash_attention_cpu.cc b/src/ops/flash_attention_cpu.cc new file mode 100644 index 000000000..742feeb16 --- /dev/null +++ b/src/ops/flash_attention_cpu.cc @@ -0,0 +1,24 @@ +#include "ctranslate2/ops/flash_attention.h" + +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + template<> + void FlashAttention::compute(StorageView&, + StorageView&, + StorageView&, + StorageView&, + StorageView*, + StorageView*, + StorageView*, + bool, + StorageView*, + StorageView*, + const bool, + StorageView*, + dim_t) const { + throw std::runtime_error("FlashAttention do not support for CPU"); + } + } +} \ No newline at end of file diff --git a/src/ops/flash_attention_gpu.cu b/src/ops/flash_attention_gpu.cu new file mode 100644 index 000000000..045c17849 --- /dev/null +++ b/src/ops/flash_attention_gpu.cu @@ -0,0 +1,367 @@ +#include "ctranslate2/ops/flash_attention.h" +#include "ctranslate2/ops/flash-attention/flash.h" +#include "ctranslate2/ops/flash-attention/static_switch.h" +#include "ctranslate2/ops/transpose.h" +#include "ctranslate2/ops/slide.h" +#include "cuda/utils.h" + +#include "dispatch.h" + +#ifndef M_LOG2E +#define M_LOG2E 1.4426950408889634074 +#endif + +namespace ctranslate2 { + namespace ops { + static void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + StorageView* q, + StorageView* k, + StorageView* v, + StorageView* out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_k, + void 
*p_d, + void *softmax_lse_d, + float softmax_scale, + int window_size_left, + int window_size_right, + bool seqlenq_ngroups_swapped=false) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.is_bf16 = q->dtype() == DataType::BFLOAT16; + + // Set the pointers and strides. + params.q_ptr = q->buffer(); + params.k_ptr = k->buffer(); + params.v_ptr = v->buffer(); + // All stride are in elements, not bytes. + params.q_row_stride = q->stride(-3); + params.k_row_stride = k->stride(-3); + params.v_row_stride = v->stride(-3); + params.q_head_stride = q->stride(-2); + params.k_head_stride = k->stride(-2); + params.v_head_stride = v->stride(-2); + params.o_ptr = out->buffer(); + params.o_row_stride = out->stride(-3); + params.o_head_stride = out->stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q->stride(0); + params.k_batch_stride = k->stride(0); + params.v_batch_stride = v->stride(0); + params.o_batch_stride = out->stride(0); + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // Set this to probability of keeping an element to simplify things. + // not use dropout + params.p_dropout = 1.f; + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0; + + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; + } + + // Find the number of splits that maximizes the occupancy. For example, if we have + // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is + // better than having 3 splits (efficiency = 0.67). However, we also don't want too many + // splits as that would incur more HBM reads/writes. + // So we find the best efficiency, then find the smallest number of splits that gets 85% + // of the best efficiency. 
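+    // Concretely: for a candidate split count, n_waves = batch_nheads_mblocks * num_splits / num_SMs,
+    // and efficiency = n_waves / ceil(n_waves), i.e. how full the last wave of thread blocks is.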
+ static int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { continue; } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; + } + + static void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, + const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, + const int head_size_rounded, + const int num_splits, cudaDeviceProp *dprops) { + + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. 
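+    // The query length is rounded up to the 64-wide M blocks used by the split-KV kernels;
+    // a num_splits value below 1 means "auto-select" via num_splits_heuristic above, and the
+    // chosen value is checked against the 128-split limit.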
+ const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; + params.num_splits = num_splits; + if (num_splits < 1) { + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); + } + TENSOR_CHECK((params.num_splits <= 128), "[FlashAttention] num_splits > 128 not supported"); + } + + void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); + } + + static const ops::Transpose transpose_op({0, 2, 1, 3}); + + template<> + void FlashAttention::compute(StorageView& queries, + StorageView& keys, + StorageView& values, + StorageView& output, + StorageView* cached_keys, + StorageView* cached_values, + StorageView* attention, + bool return_normalized_attention, + StorageView* rotary_cos, + StorageView* rotary_sin, + const bool rotary_interleave, + StorageView* alibi, + dim_t offset) const { + const Device device = queries.device(); + const DataType dtype = queries.dtype(); + StorageView rotary_cos_half(dtype, device); + StorageView rotary_sin_half(dtype, device); + + dim_t window_size_left = _sliding_window > 0 ? _sliding_window : -1; + dim_t window_size_right = _sliding_window > 0 ? 0 : -1; + + int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); + auto dprops = ctranslate2::cuda::get_device_properties(device_id); + + const auto shape = queries.shape(); + const dim_t batch_size = shape[0]; + dim_t seqlen_q = shape[1]; + dim_t num_heads = shape[2]; + const dim_t head_size_og = shape[3]; + + dim_t seqlen_k, num_heads_k; + if (offset == 0) { + seqlen_k = keys.dim(1); + num_heads_k = keys.dim(2); + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + } else { + seqlen_k = cached_keys->dim(1); + num_heads_k = cached_keys->dim(2); + } + + // causal=true is the same as causal=false in this case + bool is_causal = true; + if (seqlen_q == 1 && !alibi) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + StorageView tmp(dtype, device); + transpose_op(queries.reshape({batch_size, num_heads_k, ngroups, head_size_og}), tmp); + queries = std::move(tmp); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + if (offset > 0) { + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + StorageView softmax_lse({batch_size, num_heads, seqlen_q}, DataType::FLOAT32, device); + output.resize(queries.shape()); + if (attention && return_normalized_attention) { + attention->resize({batch_size, 
num_heads, seqlen_q_rounded, seqlen_k_rounded}); + } + bool force_split_kernel = false; + StorageView seqlens_k({batch_size}, static_cast(offset), device); + + Flash_fwd_params params; + if (offset == 0) { + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + &queries, &keys, &values, &output, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + (return_normalized_attention && attention) ? attention->buffer() : /*p_ptr=*/nullptr, + softmax_lse.buffer(), + _queries_scale, + window_size_left, + window_size_right); + + // set params splitkv + set_params_splitkv(params, batch_size, num_heads, + head_size, seqlen_k, seqlen_q, + head_size_rounded, /*num_splits*/0, &dprops); + } + else { + const int page_block_size = 1; + + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + &queries, cached_keys, cached_values, &output, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse.buffer(), + _queries_scale, + window_size_left, + window_size_right); + + int seqlen_knew = keys.dim(1); + params.seqlen_knew = seqlen_knew; + params.knew_ptr = keys.buffer(); + params.vnew_ptr = values.buffer(); + // All stride are in elements, not bytes. + params.knew_batch_stride = keys.stride(0); + params.vnew_batch_stride = values.stride(0); + params.knew_row_stride = keys.stride(-3); + params.vnew_row_stride = values.stride(-3); + params.knew_head_stride = keys.stride(-2); + params.vnew_head_stride = values.stride(-2); + params.cu_seqlens_k = static_cast(seqlens_k.buffer()); + params.is_seqlens_k_cumulative = false; + + if (rotary_cos && rotary_sin) { + params.rotary_dim = rotary_cos->dim(1); + const ops::Slide slide_op(1, 0, params.rotary_dim / 2); + slide_op(*rotary_cos, rotary_cos_half); + slide_op(*rotary_sin, rotary_sin_half); + params.rotary_cos_ptr = rotary_cos_half.buffer(); + params.rotary_sin_ptr = rotary_sin_half.buffer(); + params.is_rotary_interleaved = rotary_interleave; + } + else + params.rotary_dim = 0; + + set_params_splitkv(params, batch_size, num_heads, + head_size, seqlen_k, seqlen_q, + head_size_rounded, /*num_splits*/0, &dprops); + params.page_block_size = page_block_size; + force_split_kernel = true; + } + + StorageView softmax_lse_accum(DataType::FLOAT32, device); + StorageView out_accum(DataType::FLOAT32, device); + if (params.num_splits > 1) { + softmax_lse_accum.resize({params.num_splits, batch_size, num_heads, seqlen_q}); + out_accum.resize({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}); + params.softmax_lseaccum_ptr = softmax_lse_accum.buffer(); + params.oaccum_ptr = out_accum.buffer(); + } + params.alibi_slopes_ptr = nullptr; + + cudaStream_t stream = ctranslate2::cuda::get_cuda_stream(); + run_mha_fwd(params, stream, force_split_kernel); + + if (seqlenq_ngroups_swapped) { + StorageView tmp(dtype, device); + transpose_op(output, tmp); + output = std::move(tmp); + output.reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + } + } +} diff --git a/src/ops/rotary.cc b/src/ops/rotary.cc index 03bd11459..56455dd9b 100644 --- a/src/ops/rotary.cc +++ b/src/ops/rotary.cc @@ -14,13 +14,14 @@ namespace ctranslate2 { void Rotary::operator()(const StorageView& input, const 
                             const StorageView& sin,
                             const StorageView& cos,
-                            StorageView& output) const {
+                            StorageView& output,
+                            bool is_transposed) const {
       PROFILE("Rotary");
       output.resize_as(input);
 
       DEVICE_AND_FLOAT_DISPATCH("Rotary", input.device(), input.dtype(),
-                                (compute<D, T>(input, sin, cos, output)));
+                                (compute<D, T>(input, sin, cos, output, is_transposed)));
     }
 
   }
diff --git a/src/ops/rotary_cpu.cc b/src/ops/rotary_cpu.cc
index bdf35a5ae..8d78ce933 100644
--- a/src/ops/rotary_cpu.cc
+++ b/src/ops/rotary_cpu.cc
@@ -43,8 +43,9 @@ namespace ctranslate2 {
     void Rotary::compute(const StorageView& input,
                          const StorageView& sin,
                          const StorageView& cos,
-                         StorageView& output) const {
-      const dim_t max_time = input.dim(-2);
+                         StorageView& output,
+                         bool is_transposed) const {
+      const dim_t max_time = is_transposed ? input.dim(-2) : input.dim(-3);
       const dim_t depth = input.dim(-1);
       const dim_t batch_size = input.size() / (max_time * depth);
       const dim_t ndims = _ndims == 0 ? depth : _ndims;
@@ -65,7 +66,8 @@ namespace ctranslate2 {
   Rotary::compute<Device::CPU, T>(const StorageView&,                   \
                                   const StorageView&,                   \
                                   const StorageView&,                   \
-                                  StorageView&) const;
+                                  StorageView&,                         \
+                                  bool) const;
 
 
 DECLARE_IMPL(float)
diff --git a/src/ops/rotary_gpu.cu b/src/ops/rotary_gpu.cu
index 511608ce0..5451eb231 100644
--- a/src/ops/rotary_gpu.cu
+++ b/src/ops/rotary_gpu.cu
@@ -30,9 +30,11 @@ namespace ctranslate2 {
                                   const T* cos,
                                   T* y,
                                   const cuda::index_t max_time,
+                                  const cuda::index_t head_size,
                                   const cuda::index_t ndims,
-                                  const cuda::index_t depth) {
-      const auto time = blockIdx.x % max_time;
+                                  const cuda::index_t depth,
+                                  const bool transpose) {
+      const auto time = transpose ? blockIdx.x % max_time : blockIdx.x / head_size;
       const auto middle = ndims / 2;
 
       x += blockIdx.x * depth;
 
@@ -57,8 +59,10 @@ namespace ctranslate2 {
     void Rotary::compute(const StorageView& input,
                          const StorageView& sin,
                          const StorageView& cos,
-                         StorageView& output) const {
-      const dim_t max_time = input.dim(-2);
+                         StorageView& output,
+                         bool is_transposed) const {
+      const dim_t max_time = is_transposed ? input.dim(-2) : input.dim(-3);
+      const dim_t head_size = is_transposed ? input.dim(-3) : input.dim(-2);
       const dim_t depth = input.dim(-1);
       const dim_t ndims = _ndims == 0 ? depth : _ndims;
 
@@ -74,10 +78,10 @@ namespace ctranslate2 {
 
       if (_interleave)
         rotary_kernel<<>>(
-          x, s, c, y, max_time, ndims, depth);
+          x, s, c, y, max_time, head_size, ndims, depth, is_transposed);
       else
         rotary_kernel<<>>(
-          x, s, c, y, max_time, ndims, depth);
+          x, s, c, y, max_time, head_size, ndims, depth, is_transposed);
     }
 
 #define DECLARE_IMPL(T)                                                 \
@@ -85,7 +89,8 @@ namespace ctranslate2 {
   Rotary::compute<Device::CUDA, T>(const StorageView&,                  \
                                    const StorageView&,                  \
                                    const StorageView&,                  \
-                                   StorageView&) const;
+                                   StorageView&,                        \
+                                   bool) const;
 
 DECLARE_IMPL(float)
 DECLARE_IMPL(float16_t)

From 71108153f1951b8b1de4b39cf26c0c1b2ce668d4 Mon Sep 17 00:00:00 2001
From: minhthuc
Date: Tue, 9 Apr 2024 00:51:06 +0200
Subject: [PATCH 2/3] fix submodule cutlass

---
 third_party/cutlass | 1 +
 1 file changed, 1 insertion(+)
 create mode 160000 third_party/cutlass

diff --git a/third_party/cutlass b/third_party/cutlass
new file mode 160000
index 000000000..19f3cc33f
--- /dev/null
+++ b/third_party/cutlass
@@ -0,0 +1 @@
+Subproject commit 19f3cc33f1642b490ed7126ea0141f79c0045527

From 42ff314378eb83f34ddc9e6a7e9d4088103e4803 Mon Sep 17 00:00:00 2001
From: minhthuc
Date: Tue, 9 Apr 2024 09:52:12 +0200
Subject: [PATCH 3/3] update cutlass

---
 third_party/cutlass | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/cutlass b/third_party/cutlass
index 19f3cc33f..bbe579a9e 160000
--- a/third_party/cutlass
+++ b/third_party/cutlass
@@ -1 +1 @@
-Subproject commit 19f3cc33f1642b490ed7126ea0141f79c0045527
+Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49