Small improvement in flash attention (#1732)
minhthuc2502 authored Jun 26, 2024
1 parent 59c7dda commit 72a461a
Showing 10 changed files with 76 additions and 81 deletions.
35 changes: 35 additions & 0 deletions CMakeLists.txt
@@ -596,6 +596,41 @@ if (WITH_CUDA)
       src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
       src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
       )
+
+  set_source_files_properties(
+    src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
+    src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
+    PROPERTIES COMPILE_FLAGS "--use_fast_math")
 elseif(WITH_CUDNN)
   message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
 else()
12 changes: 7 additions & 5 deletions include/ctranslate2/layers/attention_layer.h
@@ -89,14 +89,14 @@ namespace ctranslate2 {
                        const dim_t max_position_embeddings = 0,
                        const bool transpose = true);
 
-      void apply(StorageView& x, const dim_t offset = 0, bool apply = true);
+      void apply(StorageView& x, const dim_t offset = 0, bool fa2 = false);
 
-      StorageView& get_cos() {
-        return _cos;
+      StorageView& get_cos_half() {
+        return *_cos_half;
       }
 
-      StorageView& get_sin() {
-        return _sin;
+      StorageView& get_sin_half() {
+        return *_sin_half;
       }
 
       bool get_interleave() const {
@@ -124,6 +124,8 @@ namespace ctranslate2 {
 
       StorageView _sin;
      StorageView _cos;
+      std::unique_ptr<StorageView> _sin_half;
+      std::unique_ptr<StorageView> _cos_half;
     };
 
 
2 changes: 1 addition & 1 deletion include/ctranslate2/layers/flash_attention.h
@@ -47,7 +47,7 @@ namespace ctranslate2 {
                dim_t beam_size = 1);
 
      const dim_t _cache_time_dim;
-      static constexpr dim_t _offset_free_space{100};
+      static constexpr dim_t _offset_free_space{512};
     };
   }
 }
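
The headroom bump from 100 to 512 positions appears to let the flash-attention KV cache grow far less often during long generations, at the cost of a slightly larger allocation. A toy Python illustration of the amortization; it only models "grow to needed length plus _offset_free_space", and the layer's real growth policy may differ:

def count_reallocations(total_steps: int, headroom: int) -> int:
    """Count how often a cache grown in 'needed + headroom' increments must be
    reallocated while decoding total_steps tokens (simplified model)."""
    capacity, reallocs = 0, 0
    for length in range(1, total_steps + 1):
        if length > capacity:
            capacity = length + headroom
            reallocs += 1
    return reallocs

# For 4096 decoded tokens: headroom=100 -> 41 reallocations, headroom=512 -> 8.
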
51 changes: 0 additions & 51 deletions include/ctranslate2/ops/flash-attention/philox.cuh

This file was deleted.

1 change: 0 additions & 1 deletion include/ctranslate2/ops/flash-attention/softmax.h
@@ -10,7 +10,6 @@
 
 #include <cutlass/numeric_types.h>
 
-#include "philox.cuh"
 #include "utils.h"
 
 #ifndef M_LOG2E
16 changes: 14 additions & 2 deletions src/layers/attention_layer.cc
@@ -201,7 +201,7 @@ namespace ctranslate2 {
         _rotary_scaling_short_factor = std::make_unique<StorageView>(_rotary_scaling_short_factor->to(Device::CPU));
     }
 
-    void RotaryEmbeddings::apply(StorageView& x, const dim_t offset, bool apply) {
+    void RotaryEmbeddings::apply(StorageView& x, const dim_t offset, bool fa2) {
       const Device device = x.device();
       const DataType dtype = x.dtype();
       const dim_t max_time = _transpose ? x.dim(-2) : x.dim(-3);
@@ -211,8 +211,20 @@
         const dim_t cur_num_positions = _sin ? _sin.dim(0) : 0;
         const dim_t new_num_positions = std::max(offset + max_time, cur_num_positions + _num_initial_positions);
         initialize(new_num_positions, dim, device, dtype);
+        if (fa2) {
+          if (!_sin_half)
+          {
+            _sin_half = std::make_unique<StorageView>(dtype, device);
+            _cos_half = std::make_unique<StorageView>(dtype, device);
+          }
+          const ops::Slide slide_op(1, 0, dim / 2);
+          slide_op(_cos, *_cos_half);
+          slide_op(_sin, *_sin_half);
+          if (offset != 0)
+            return;
+        }
       }
-      if (!apply)
+      if (offset != 0 && fa2)
         return;
 
       StorageView sin(dtype, device);
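
The new fa2 branch above caches half-width rotary tables: whenever the full _sin/_cos tables are rebuilt, their first dim / 2 columns are sliced once into _sin_half/_cos_half, so later decode steps no longer slice on every call. A minimal NumPy sketch of the idea (a hypothetical helper, not CTranslate2's StorageView/Slide API); with non-interleaved RoPE the first half of each table already holds every distinct frequency, so the half tables are a plain slice:

import numpy as np

def build_rotary_tables(num_positions, dim, base=10000.0):
    # Non-interleaved (NeoX-style) tables: each frequency appears twice along
    # the last axis, so columns [0, dim // 2) already cover all frequencies.
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))      # (dim // 2,)
    angles = np.arange(num_positions)[:, None] * inv_freq        # (pos, dim // 2)
    full = np.concatenate([angles, angles], axis=-1)             # (pos, dim)
    sin, cos = np.sin(full), np.cos(full)
    sin_half, cos_half = sin[:, : dim // 2], cos[:, : dim // 2]  # cached once
    return sin, cos, sin_half, cos_half
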
15 changes: 8 additions & 7 deletions src/layers/flash_attention.cc
@@ -16,17 +16,18 @@ namespace ctranslate2 {
     }
 
     void FlashMultiHeadAttention::operator()(const StorageView& queries,
-                                             const StorageView& values,
+                                             const StorageView&,
                                              const StorageView* values_lengths,
                                              StorageView& output,
                                              StorageView* cached_keys,
                                              StorageView* cached_values,
                                              StorageView* attention,
                                              const Padder* queries_padder,
-                                             const Padder* values_padder,
+                                             const Padder*,
                                              bool return_normalized_attention,
-                                             StorageView* position_bias,
+                                             StorageView*,
                                              dim_t offset) const {
+      PROFILE("MultiHeadAttention");
       const Device device = queries.device();
       const DataType dtype = queries.dtype();
 
@@ -63,8 +64,8 @@ namespace ctranslate2 {
       }
 
       if (_rotary_embeddings) {
-        _rotary_embeddings->apply(queries_proj, offset, offset == 0);
-        _rotary_embeddings->apply(keys_proj, offset, offset == 0);
+        _rotary_embeddings->apply(queries_proj, offset, true);
+        _rotary_embeddings->apply(keys_proj, offset, true);
       }
 
       if (cached_keys != nullptr) {
@@ -102,8 +103,8 @@
       StorageView* rotary_sin = nullptr;
       bool rotary_interleaved = false;
       if (_rotary_embeddings && offset > 0) {
-        rotary_cos = &(_rotary_embeddings->get_cos());
-        rotary_sin = &(_rotary_embeddings->get_sin());
+        rotary_cos = &(_rotary_embeddings->get_cos_half());
+        rotary_sin = &(_rotary_embeddings->get_sin_half());
         rotary_interleaved = _rotary_embeddings->get_interleave();
       }
 
1 change: 1 addition & 0 deletions src/ops/flash_attention.cc
@@ -23,6 +23,7 @@ namespace ctranslate2 {
                                  const bool rotary_interleave,
                                  StorageView* alibi,
                                  dim_t offset) const {
+      PROFILE("FlashAttention");
       DEVICE_DISPATCH(queries.device(), compute<D>(queries, keys, values, output, cached_keys, cached_values,
                                                    attention, return_normalized_attention,
                                                    rotary_cos, rotary_sin, rotary_interleave, alibi, offset));
11 changes: 3 additions & 8 deletions src/ops/flash_attention_gpu.cu
@@ -205,8 +205,6 @@ namespace ctranslate2 {
                                       dim_t offset) const {
       const Device device = queries.device();
       const DataType dtype = queries.dtype();
-      StorageView rotary_cos_half(dtype, device);
-      StorageView rotary_sin_half(dtype, device);
 
       dim_t window_size_left = _sliding_window > 0 ? _sliding_window : -1;
       dim_t window_size_right = _sliding_window > 0 ? 0 : -1;
@@ -324,12 +322,9 @@
         params.is_seqlens_k_cumulative = false;
 
         if (rotary_cos && rotary_sin) {
-          params.rotary_dim = rotary_cos->dim(1);
-          const ops::Slide slide_op(1, 0, params.rotary_dim / 2);
-          slide_op(*rotary_cos, rotary_cos_half);
-          slide_op(*rotary_sin, rotary_sin_half);
-          params.rotary_cos_ptr = rotary_cos_half.buffer();
-          params.rotary_sin_ptr = rotary_sin_half.buffer();
+          params.rotary_dim = rotary_cos->dim(1) * 2;
+          params.rotary_cos_ptr = rotary_cos->buffer();
+          params.rotary_sin_ptr = rotary_sin->buffer();
           params.is_rotary_interleaved = rotary_interleave;
         }
         else
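
Since the kernel now receives the pre-sliced half tables, the per-call Slide and the two temporaries are gone, and params.rotary_dim is recovered as rotary_cos->dim(1) * 2. A rough NumPy sketch of the non-interleaved rotation applied with half-width tables (illustrative only; shapes and names are assumptions, not the kernel's API):

import numpy as np

def rotate_with_half_tables(x, cos_half, sin_half, position):
    # x: (..., head_dim); cos_half/sin_half: (num_positions, rotary_dim // 2).
    rotary_dim = 2 * cos_half.shape[1]           # mirrors rotary_cos->dim(1) * 2
    c, s = cos_half[position], sin_half[position]
    x1 = x[..., : rotary_dim // 2]
    x2 = x[..., rotary_dim // 2 : rotary_dim]
    rotated = np.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)
    return np.concatenate([rotated, x[..., rotary_dim:]], axis=-1)
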
13 changes: 7 additions & 6 deletions tools/benchmark_tensor_parallel/benchmark.py
@@ -49,8 +49,8 @@ def process_prompt(generator, max_generation_length, generated_token, prompt):
     step_results = generator.generate_tokens(
         prompt,
         max_length=max_generation_length,
-        sampling_temperature=0.6,
-        sampling_topk=20,
+        sampling_temperature=0.75,
+        sampling_topk=1,
         sampling_topp=1,
     )
     for step_result in step_results:
@@ -77,13 +77,13 @@ def benchmark_generation(generator,
         step_results = generator.generate_tokens(
             prompt_tokens[i:i + batch_size],
             max_length=max_generation_length,
-            sampling_temperature=0.6,
-            sampling_topk=20,
+            sampling_temperature=0.75,
+            sampling_topk=1,
             sampling_topp=1,
         )
         for step_result in step_results:
             batch_id = step_result.batch_id
-            generated_token[batch_id].append(step_result.token)
+            generated_token[i + batch_id].append(step_result.token)
     end_all = time.time()
     elapsed_time = end_all - start_all
     num_tokens = count_tokens(generated_token)
@@ -148,7 +148,8 @@ def main():
     args = parser.parse_args()
 
     print("Loading the model...")
-    generator = ctranslate2.Generator(args.model_path, device="cuda", tensor_parallel=True, inter_threads=2)
+    generator = ctranslate2.Generator(args.model_path, device="cuda", tensor_parallel=True,
+                                      flash_attention=False, inter_threads=2)
     sp = spm.SentencePieceProcessor(os.path.join(args.model_path, "tokenizer.model"))
 
     if not os.path.exists(args.src):
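
The benchmark now decodes greedily (sampling_topk=1 with temperature 0.75), builds the Generator with flash_attention=False, and fixes the token bookkeeping: generate_tokens reports batch_id relative to the submitted slice, so results belong at index i + batch_id. A condensed sketch assembled from the calls shown above (model path and prompt list are placeholders):

import ctranslate2

def generate_batched(model_path, prompt_tokens, batch_size=8, max_generation_length=256):
    generator = ctranslate2.Generator(model_path, device="cuda", tensor_parallel=True,
                                      flash_attention=False, inter_threads=2)
    generated_token = [[] for _ in prompt_tokens]
    for i in range(0, len(prompt_tokens), batch_size):
        step_results = generator.generate_tokens(
            prompt_tokens[i:i + batch_size],
            max_length=max_generation_length,
            sampling_temperature=0.75,
            sampling_topk=1,
            sampling_topp=1,
        )
        for step_result in step_results:
            # batch_id is relative to this batch, hence the i + batch_id offset.
            generated_token[i + step_result.batch_id].append(step_result.token)
    return generated_token
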
