Flash attention #1651

Merged 3 commits on Apr 9, 2024
3 changes: 3 additions & 0 deletions .gitmodules
@@ -16,3 +16,6 @@
[submodule "third_party/ruy"]
path = third_party/ruy
url = https://github.com/google/ruy.git
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
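
Since this adds CUTLASS as a submodule, a fresh checkout needs the submodules fetched before building, for example with `git submodule update --init --recursive`; CUTLASS is header-only, so no separate build step should be needed.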
51 changes: 49 additions & 2 deletions CMakeLists.txt
@@ -113,7 +113,9 @@ set(SOURCES
src/env.cc
src/filesystem.cc
src/generator.cc
src/layers/attention_layer.cc
src/layers/attention.cc
src/layers/flash_attention.cc
src/layers/common.cc
src/layers/decoder.cc
src/layers/transformer.cc
@@ -141,6 +143,8 @@ set(SOURCES
src/ops/cos.cc
src/ops/dequantize.cc
src/ops/dequantize_cpu.cc
src/ops/flash_attention.cc
src/ops/flash_attention_cpu.cc
src/ops/gather.cc
src/ops/gather_cpu.cc
src/ops/gelu.cc
@@ -447,9 +451,8 @@ if (WITH_CUDA)
else()
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=/MT$<$<CONFIG:Debug>:d>")
endif()
else()
list(APPEND CUDA_NVCC_FLAGS "-std=c++17")
endif()
list(APPEND CUDA_NVCC_FLAGS "-std=c++17")
if(OpenMP_CXX_FOUND)
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=${OpenMP_CXX_FLAGS}")
endif()
@@ -469,6 +472,11 @@ if (WITH_CUDA)
list(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
set(CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})

# flags for flash attention
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda")
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")

message(STATUS "NVCC host compiler: ${CUDA_HOST_COMPILER}")
message(STATUS "NVCC compilation flags: ${CUDA_NVCC_FLAGS}")

@@ -482,6 +490,12 @@ if (WITH_CUDA)
cuda_include_directories(${THRUST_INCLUDE_DIRS})
list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${THRUST_INCLUDE_DIRS})

set(CUTLASS_INCLUDE_DIRS
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include
)
cuda_include_directories(${CUTLASS_INCLUDE_DIRS})
list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIRS})

if(WITH_CUDNN)
# Find cuDNN includes.
find_path(CUDNN_INCLUDE_DIR NAMES cudnn.h HINTS ${CUDA_TOOLKIT_ROOT_DIR}/include)
@@ -531,6 +545,7 @@ if (WITH_CUDA)
src/ops/concat_split_slide_gpu.cu
src/ops/conv1d_gpu.cu
src/ops/dequantize_gpu.cu
src/ops/flash_attention_gpu.cu
src/ops/gather_gpu.cu
src/ops/gumbel_max_gpu.cu
src/ops/layer_norm_gpu.cu
@@ -544,6 +559,38 @@ if (WITH_CUDA)
src/ops/topp_mask_gpu.cu
src/ops/quantize_gpu.cu
src/ops/nccl_ops_gpu.cu
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
)
elseif(WITH_CUDNN)
message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
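
The forward kernels are split into one `.cu` file per head dimension (32–256) and data type (fp16/bf16) so the heavy CUTLASS template instantiations can be compiled in parallel; the `flash_fwd_split_*` files cover the split-KV decoding variant used with long cached sequences. A minimal sketch of the head-dimension dispatch idea, using hypothetical names (`run_flash_fwd`, `dispatch_flash_fwd`) rather than the PR's actual functions:

```cpp
#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for the per-file template instantiations listed above;
// the real kernels take launch parameters and a CUDA stream.
template <int HeadDim, bool IsBF16>
void run_flash_fwd() {
  std::printf("launch flash_fwd: hdim=%d, dtype=%s\n", HeadDim, IsBF16 ? "bf16" : "fp16");
}

// Round the runtime head dimension up to the nearest instantiated size
// (hypothetical dispatcher, not code from this PR).
template <bool IsBF16>
void dispatch_flash_fwd(int head_dim) {
  if      (head_dim <= 32)  run_flash_fwd<32,  IsBF16>();
  else if (head_dim <= 64)  run_flash_fwd<64,  IsBF16>();
  else if (head_dim <= 96)  run_flash_fwd<96,  IsBF16>();
  else if (head_dim <= 128) run_flash_fwd<128, IsBF16>();
  else if (head_dim <= 160) run_flash_fwd<160, IsBF16>();
  else if (head_dim <= 192) run_flash_fwd<192, IsBF16>();
  else if (head_dim <= 224) run_flash_fwd<224, IsBF16>();
  else if (head_dim <= 256) run_flash_fwd<256, IsBF16>();
  else throw std::invalid_argument("head dimensions above 256 are not supported");
}

int main() {
  dispatch_flash_fwd</*IsBF16=*/false>(64);   // -> hdim=64, fp16
  dispatch_flash_fwd</*IsBF16=*/true>(100);   // -> hdim=128, bf16
}
```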
98 changes: 15 additions & 83 deletions include/ctranslate2/layers/attention.h
@@ -1,6 +1,6 @@
#pragma once

#include "ctranslate2/layers/common.h"
#include "ctranslate2/layers/attention_layer.h"
#include "ctranslate2/padder.h"

namespace ctranslate2 {
@@ -13,7 +13,7 @@ namespace ctranslate2 {
class RotaryEmbeddings;
class Alibi;

class MultiHeadAttention : public Layer
class MultiHeadAttention : public AttentionLayer
{
public:
MultiHeadAttention(const models::Model& model,
@@ -25,7 +25,7 @@ namespace ctranslate2 {
Alibi* alibi = nullptr);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& queries,
virtual void operator()(const StorageView& queries,
const StorageView& values,
const StorageView* values_lengths,
StorageView& output,
@@ -36,95 +36,27 @@ namespace ctranslate2 {
const Padder* values_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr,
dim_t offset = 0) const;
dim_t offset = 0) const override;

bool has_positional_embeddings() const {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
virtual bool has_positional_embeddings() const override {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
}

bool multi_query() const {
return _multi_query;
}

static StorageView prepare_length_mask(const StorageView& lengths,
const dim_t num_heads,
const dim_t num_queries,
const bool mask_future = false,
const bool multi_query = false);

private:
const bool _tensor_parallel;
const dim_t _num_heads;
const bool _self_attention;
const bool _is_decoder;
const std::vector<Dense> _linear;
const dim_t _d_model;
const dim_t _d_head;
const bool _pre_norm;
const std::unique_ptr<const LayerNorm> _layer_norm;
const std::unique_ptr<RotaryEmbeddings> _rotary_embeddings;
Alibi* _alibi;
static void split_heads(StorageView& x,
dim_t num_heads,
const Padder* padder = nullptr,
dim_t beam_size = 1);

static void combine_heads(StorageView& x,
dim_t num_heads,
const Padder* padder = nullptr,
dim_t beam_size = 1);
const StorageView* _relative_attention_bias;
const StorageView* _relative_position_keys;
const StorageView* _relative_position_values;
dim_t _maximum_relative_position;
const float _queries_scale;
const bool _multi_query;
const dim_t _num_heads_kv;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
const dim_t _sliding_window;
};

enum class RotaryScalingType {
None = -1,
Linear,
};

class RotaryEmbeddings {
public:
RotaryEmbeddings(const dim_t dim = 0,
const bool interleave = true,
const RotaryScalingType scaling_type = RotaryScalingType::None,
const float scaling_factor = 1,
const float base = 10000,
const dim_t num_initial_positions = 2048);

void apply(StorageView& x, const dim_t offset = 0);

private:
void initialize(const dim_t num_positions,
const dim_t dim,
const Device device,
const DataType dtype);

const dim_t _dim;
const bool _interleave;
const RotaryScalingType _scaling_type;
const float _scaling_factor;
const float _base;
const dim_t _num_initial_positions;
const ops::Rotary _rotary_op;

StorageView _sin;
StorageView _cos;
};


class Alibi {
public:
Alibi(const bool use_positive_positions = false, const bool scale_alibi = false, const dim_t num_initial_positions = 2048);

void apply(StorageView& x, const float scale = 1);

private:
const bool _use_positive_positions;
const dim_t _num_initial_positions;
const bool _scale_alibi;
const ops::AlibiAdd _alibi_op;

StorageView _alibi;
};

}
}
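
The common state and helpers of `MultiHeadAttention` move into the new `AttentionLayer` base class (see the new header below), and `operator()` becomes virtual, so decoder code can hold either the standard or the flash implementation behind one interface. A minimal, self-contained sketch of that pattern, using simplified hypothetical classes (the `*Sketch` types and `make_attention`) rather than the PR's real constructors:

```cpp
#include <cstdio>
#include <memory>

// Simplified stand-ins for the refactored classes; the real constructors take a
// models::Model, a scope, head counts, etc.
struct AttentionLayerSketch {
  virtual ~AttentionLayerSketch() = default;
  virtual void operator()() const = 0;  // the real signature takes queries, values, output, ...
};

struct MultiHeadAttentionSketch : AttentionLayerSketch {
  void operator()() const override { std::printf("standard attention path\n"); }
};

struct FlashAttentionSketch : AttentionLayerSketch {
  void operator()() const override { std::printf("flash attention path\n"); }
};

// Hypothetical factory: decoder layers can hold either implementation
// behind the common AttentionLayer interface.
std::unique_ptr<AttentionLayerSketch> make_attention(bool use_flash_attention) {
  if (use_flash_attention)
    return std::make_unique<FlashAttentionSketch>();
  return std::make_unique<MultiHeadAttentionSketch>();
}

int main() {
  const auto attention = make_attention(/*use_flash_attention=*/true);
  (*attention)();  // dispatches through the shared virtual operator()
}
```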
136 changes: 136 additions & 0 deletions include/ctranslate2/layers/attention_layer.h
@@ -0,0 +1,136 @@
#pragma once

#include "ctranslate2/layers/common.h"
#include "ctranslate2/padder.h"

namespace ctranslate2 {
namespace layers {
StorageView make_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t max_position);

class RotaryEmbeddings;
class Alibi;

class AttentionLayer : public Layer
{
public:
AttentionLayer(const models::Model& model,
const std::string& scope,
dim_t num_heads,
bool self_attention,
bool pre_norm = true,
bool is_decoder = false,
Alibi* alibi = nullptr,
bool is_flash_attn = false);
virtual ~AttentionLayer() {};
DataType output_type() const override;
dim_t output_size() const override;
virtual void operator()(const StorageView& queries,
const StorageView& values,
const StorageView* values_lengths,
StorageView& output,
StorageView* cached_keys = nullptr,
StorageView* cached_values = nullptr,
StorageView* attention = nullptr,
const Padder* queries_padder = nullptr,
const Padder* values_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr,
dim_t offset = 0) const = 0;

virtual bool has_positional_embeddings() const = 0;

bool multi_query() const {
return _multi_query;
}

static StorageView prepare_length_mask(const StorageView& lengths,
const dim_t num_heads,
const dim_t num_queries,
const bool mask_future = false,
const bool multi_query = false);

protected:
const bool _tensor_parallel;
const dim_t _num_heads;
const bool _self_attention;
const bool _is_decoder;
const std::vector<Dense> _linear;
const dim_t _d_model;
const dim_t _d_head;
const bool _pre_norm;
const std::unique_ptr<const LayerNorm> _layer_norm;
const std::unique_ptr<RotaryEmbeddings> _rotary_embeddings;
Alibi* _alibi;
const float _queries_scale;
const bool _multi_query;
const dim_t _num_heads_kv;
const dim_t _sliding_window;
};

enum class RotaryScalingType {
None = -1,
Linear,
};

class RotaryEmbeddings {
public:
RotaryEmbeddings(const dim_t dim = 0,
const bool interleave = true,
const RotaryScalingType scaling_type = RotaryScalingType::None,
const float scaling_factor = 1,
const float base = 10000,
const dim_t num_initial_positions = 2048,
const bool transpose = true);

void apply(StorageView& x, const dim_t offset = 0, bool apply = true);

StorageView& get_cos() {
return _cos;
}

StorageView& get_sin() {
return _sin;
}

bool get_interleave() const {
return _interleave;
}

private:
void initialize(const dim_t num_positions,
const dim_t dim,
const Device device,
const DataType dtype);

const dim_t _dim;
const bool _interleave;
const RotaryScalingType _scaling_type;
const float _scaling_factor;
const float _base;
const dim_t _num_initial_positions;
const ops::Rotary _rotary_op;
const bool _transpose;

StorageView _sin;
StorageView _cos;
};


class Alibi {
public:
Alibi(const bool use_positive_positions = false, const bool scale_alibi = false, const dim_t num_initial_positions = 2048);

void apply(StorageView& x, const float scale = 1);

private:
const bool _use_positive_positions;
const dim_t _num_initial_positions;
const bool _scale_alibi;
const ops::AlibiAdd _alibi_op;

StorageView _alibi;
};
}
}
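
A note on the `RotaryEmbeddings` changes: the new `transpose` constructor argument, the `apply` flag on `apply()`, and the `get_cos`/`get_sin`/`get_interleave` accessors make it possible to build and read the sin/cos tables without rotating the input tensor in a separate op; a plausible reading is that the flash-attention path passes these tables to the fused kernel and applies the rotation there.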