Flash attention (OpenNMT#1651)
* flash attention support
minhthuc2502 authored Apr 9, 2024
1 parent 8994330 commit 7d63eea
Showing 76 changed files with 4,727 additions and 409 deletions.
.gitmodules (3 changes: 3 additions & 0 deletions)
@@ -16,3 +16,6 @@
[submodule "third_party/ruy"]
path = third_party/ruy
url = https://github.com/google/ruy.git
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
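Note: CUTLASS supplies the CUDA template primitives that the vendored flash-attention kernels build on (the CMake changes below add its headers to the include path). After checking out this commit, the new dependency can be fetched with `git submodule update --init --recursive`.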
CMakeLists.txt (51 changes: 49 additions & 2 deletions)
@@ -113,7 +113,9 @@ set(SOURCES
src/env.cc
src/filesystem.cc
src/generator.cc
src/layers/attention_layer.cc
src/layers/attention.cc
src/layers/flash_attention.cc
src/layers/common.cc
src/layers/decoder.cc
src/layers/transformer.cc
@@ -141,6 +143,8 @@ set(SOURCES
src/ops/cos.cc
src/ops/dequantize.cc
src/ops/dequantize_cpu.cc
src/ops/flash_attention.cc
src/ops/flash_attention_cpu.cc
src/ops/gather.cc
src/ops/gather_cpu.cc
src/ops/gelu.cc
@@ -447,9 +451,8 @@ if (WITH_CUDA)
else()
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=/MT$<$<CONFIG:Debug>:d>")
endif()
else()
list(APPEND CUDA_NVCC_FLAGS "-std=c++17")
endif()
list(APPEND CUDA_NVCC_FLAGS "-std=c++17")
if(OpenMP_CXX_FOUND)
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=${OpenMP_CXX_FLAGS}")
endif()
@@ -469,6 +472,11 @@ if (WITH_CUDA)
list(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
set(CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})

# flags for flash attention
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda")
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")

message(STATUS "NVCC host compiler: ${CUDA_HOST_COMPILER}")
message(STATUS "NVCC compilation flags: ${CUDA_NVCC_FLAGS}")

@@ -482,6 +490,12 @@ if (WITH_CUDA)
cuda_include_directories(${THRUST_INCLUDE_DIRS})
list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${THRUST_INCLUDE_DIRS})

set(CUTLASS_INCLUDE_DIRS
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include
)
cuda_include_directories(${CUTLASS_INCLUDE_DIRS})
list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIRS})

if(WITH_CUDNN)
# Find cuDNN includes.
find_path(CUDNN_INCLUDE_DIR NAMES cudnn.h HINTS ${CUDA_TOOLKIT_ROOT_DIR}/include)
@@ -531,6 +545,7 @@ if (WITH_CUDA)
src/ops/concat_split_slide_gpu.cu
src/ops/conv1d_gpu.cu
src/ops/dequantize_gpu.cu
src/ops/flash_attention_gpu.cu
src/ops/gather_gpu.cu
src/ops/gumbel_max_gpu.cu
src/ops/layer_norm_gpu.cu
@@ -544,6 +559,38 @@ if (WITH_CUDA)
src/ops/topp_mask_gpu.cu
src/ops/quantize_gpu.cu
src/ops/nccl_ops_gpu.cu
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
)
elseif(WITH_CUDNN)
message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
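The long flash_fwd_*.cu list above splits the kernel instantiations, one per (head dimension, data type) pair plus a split-KV variant of each, into separate translation units so nvcc can compile them in parallel with bounded per-process memory. A minimal C++ sketch of the dispatch pattern such files imply; the names here are hypothetical, and the real dispatcher ships with the vendored flash-attention sources:

```cpp
#include <stdexcept>

enum class DType { Float16, BFloat16 };

// Hypothetical kernel entry point: in the real build, each explicit
// instantiation lives in its own flash_fwd_hdim<N>_<dtype>_sm80.cu file
// (the flash_fwd_split_* files follow the same pattern for split-KV).
template <int HeadDim, DType Dtype>
void run_flash_fwd() { /* launch the templated CUDA kernel */ }

// Round the model's head dimension up to the nearest instantiated size.
template <DType Dtype>
void dispatch_flash_fwd(int head_dim) {
  if      (head_dim <= 32)  run_flash_fwd<32,  Dtype>();
  else if (head_dim <= 64)  run_flash_fwd<64,  Dtype>();
  else if (head_dim <= 96)  run_flash_fwd<96,  Dtype>();
  else if (head_dim <= 128) run_flash_fwd<128, Dtype>();
  else if (head_dim <= 160) run_flash_fwd<160, Dtype>();
  else if (head_dim <= 192) run_flash_fwd<192, Dtype>();
  else if (head_dim <= 224) run_flash_fwd<224, Dtype>();
  else if (head_dim <= 256) run_flash_fwd<256, Dtype>();
  else throw std::invalid_argument("flash attention: head_dim must be <= 256");
}
```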
include/ctranslate2/layers/attention.h (98 changes: 15 additions & 83 deletions)
@@ -1,6 +1,6 @@
#pragma once

#include "ctranslate2/layers/common.h"
#include "ctranslate2/layers/attention_layer.h"
#include "ctranslate2/padder.h"

namespace ctranslate2 {
@@ -13,7 +13,7 @@ namespace ctranslate2 {
class RotaryEmbeddings;
class Alibi;

class MultiHeadAttention : public Layer
class MultiHeadAttention : public AttentionLayer
{
public:
MultiHeadAttention(const models::Model& model,
@@ -25,7 +25,7 @@ namespace ctranslate2 {
Alibi* alibi = nullptr);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& queries,
virtual void operator()(const StorageView& queries,
const StorageView& values,
const StorageView* values_lengths,
StorageView& output,
@@ -36,95 +36,27 @@ namespace ctranslate2 {
const Padder* values_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr,
dim_t offset = 0) const;
dim_t offset = 0) const override;

bool has_positional_embeddings() const {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
virtual bool has_positional_embeddings() const override {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
}

bool multi_query() const {
return _multi_query;
}

static StorageView prepare_length_mask(const StorageView& lengths,
const dim_t num_heads,
const dim_t num_queries,
const bool mask_future = false,
const bool multi_query = false);

private:
const bool _tensor_parallel;
const dim_t _num_heads;
const bool _self_attention;
const bool _is_decoder;
const std::vector<Dense> _linear;
const dim_t _d_model;
const dim_t _d_head;
const bool _pre_norm;
const std::unique_ptr<const LayerNorm> _layer_norm;
const std::unique_ptr<RotaryEmbeddings> _rotary_embeddings;
Alibi* _alibi;
static void split_heads(StorageView& x,
dim_t num_heads,
const Padder* padder = nullptr,
dim_t beam_size = 1);

static void combine_heads(StorageView& x,
dim_t num_heads,
const Padder* padder = nullptr,
dim_t beam_size = 1);
const StorageView* _relative_attention_bias;
const StorageView* _relative_position_keys;
const StorageView* _relative_position_values;
dim_t _maximum_relative_position;
const float _queries_scale;
const bool _multi_query;
const dim_t _num_heads_kv;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
const dim_t _sliding_window;
};

enum class RotaryScalingType {
None = -1,
Linear,
};

class RotaryEmbeddings {
public:
RotaryEmbeddings(const dim_t dim = 0,
const bool interleave = true,
const RotaryScalingType scaling_type = RotaryScalingType::None,
const float scaling_factor = 1,
const float base = 10000,
const dim_t num_initial_positions = 2048);

void apply(StorageView& x, const dim_t offset = 0);

private:
void initialize(const dim_t num_positions,
const dim_t dim,
const Device device,
const DataType dtype);

const dim_t _dim;
const bool _interleave;
const RotaryScalingType _scaling_type;
const float _scaling_factor;
const float _base;
const dim_t _num_initial_positions;
const ops::Rotary _rotary_op;

StorageView _sin;
StorageView _cos;
};


class Alibi {
public:
Alibi(const bool use_positive_positions = false, const bool scale_alibi = false, const dim_t num_initial_positions = 2048);

void apply(StorageView& x, const float scale = 1);

private:
const bool _use_positive_positions;
const dim_t _num_initial_positions;
const bool _scale_alibi;
const ops::AlibiAdd _alibi_op;

StorageView _alibi;
};

}
}
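Net effect of the attention.h changes: the shared members and the virtual `operator()` interface move into a new `AttentionLayer` base class (next file), leaving `MultiHeadAttention` as one interchangeable implementation. A sketch of the construction-time selection this enables; the `FlashAttention` constructor signature is assumed here, not taken from this commit:

```cpp
#include <memory>

// Hypothetical factory, not from this commit: choose the attention
// implementation behind the common AttentionLayer interface.
// (Assumes the ctranslate2 headers above are included.)
std::unique_ptr<layers::AttentionLayer>
make_self_attention(const models::Model& model,
                    const std::string& scope,
                    dim_t num_heads,
                    bool use_flash_attention) {
  if (use_flash_attention)  // assumes FlashAttention mirrors this constructor
    return std::make_unique<layers::FlashAttention>(model, scope, num_heads,
                                                    /*self_attention=*/true);
  return std::make_unique<layers::MultiHeadAttention>(model, scope, num_heads,
                                                      /*self_attention=*/true);
}
```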
include/ctranslate2/layers/attention_layer.h (136 changes: 136 additions & 0 deletions)
@@ -0,0 +1,136 @@
#pragma once

#include "ctranslate2/layers/common.h"
#include "ctranslate2/padder.h"

namespace ctranslate2 {
namespace layers {
StorageView make_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t max_position);

class RotaryEmbeddings;
class Alibi;

class AttentionLayer : public Layer
{
public:
AttentionLayer(const models::Model& model,
const std::string& scope,
dim_t num_heads,
bool self_attention,
bool pre_norm = true,
bool is_decoder = false,
Alibi* alibi = nullptr,
bool is_flash_attn = false);
virtual ~AttentionLayer() {};
DataType output_type() const override;
dim_t output_size() const override;
virtual void operator()(const StorageView& queries,
const StorageView& values,
const StorageView* values_lengths,
StorageView& output,
StorageView* cached_keys = nullptr,
StorageView* cached_values = nullptr,
StorageView* attention = nullptr,
const Padder* queries_padder = nullptr,
const Padder* values_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr,
dim_t offset = 0) const = 0;

virtual bool has_positional_embeddings() const = 0;

bool multi_query() const {
return _multi_query;
}

static StorageView prepare_length_mask(const StorageView& lengths,
const dim_t num_heads,
const dim_t num_queries,
const bool mask_future = false,
const bool multi_query = false);

protected:
const bool _tensor_parallel;
const dim_t _num_heads;
const bool _self_attention;
const bool _is_decoder;
const std::vector<Dense> _linear;
const dim_t _d_model;
const dim_t _d_head;
const bool _pre_norm;
const std::unique_ptr<const LayerNorm> _layer_norm;
const std::unique_ptr<RotaryEmbeddings> _rotary_embeddings;
Alibi* _alibi;
const float _queries_scale;
const bool _multi_query;
const dim_t _num_heads_kv;
const dim_t _sliding_window;
};

enum class RotaryScalingType {
None = -1,
Linear,
};

class RotaryEmbeddings {
public:
RotaryEmbeddings(const dim_t dim = 0,
const bool interleave = true,
const RotaryScalingType scaling_type = RotaryScalingType::None,
const float scaling_factor = 1,
const float base = 10000,
const dim_t num_initial_positions = 2048,
const bool transpose = true);

void apply(StorageView& x, const dim_t offset = 0, bool apply = true);

StorageView& get_cos() {
return _cos;
}

StorageView& get_sin() {
return _sin;
}

bool get_interleave() const {
return _interleave;
}

private:
void initialize(const dim_t num_positions,
const dim_t dim,
const Device device,
const DataType dtype);

const dim_t _dim;
const bool _interleave;
const RotaryScalingType _scaling_type;
const float _scaling_factor;
const float _base;
const dim_t _num_initial_positions;
const ops::Rotary _rotary_op;
const bool _transpose;

StorageView _sin;
StorageView _cos;
};


class Alibi {
public:
Alibi(const bool use_positive_positions = false, const bool scale_alibi = false, const dim_t num_initial_positions = 2048);

void apply(StorageView& x, const float scale = 1);

private:
const bool _use_positive_positions;
const dim_t _num_initial_positions;
const bool _scale_alibi;
const ops::AlibiAdd _alibi_op;

StorageView _alibi;
};
}
}
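RotaryEmbeddings grows a `transpose` constructor flag, an `apply` switch on `apply()`, and public `get_cos()`/`get_sin()` accessors. A hedged sketch of the calling pattern this appears to enable, assuming that `apply = false` only materializes the sin/cos caches so a fused attention kernel can consume the raw tables:

```cpp
// Sketch under assumed semantics; parameter values are illustrative only.
layers::RotaryEmbeddings rotary(/*dim=*/64,
                                /*interleave=*/false,
                                layers::RotaryScalingType::None,
                                /*scaling_factor=*/1,
                                /*base=*/10000,
                                /*num_initial_positions=*/2048,
                                /*transpose=*/false);  // layout for fused kernels

rotary.apply(queries, /*offset=*/0, /*apply=*/false);  // assumed: fill caches only
const StorageView& cos_table = rotary.get_cos();       // hand to the kernel
const StorageView& sin_table = rotary.get_sin();
```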