
Commit

Fix some logic to optimize it and make it work correctly
minhthuc2502 committed Mar 28, 2024
1 parent 39a5ddd commit b7338ce
Showing 9 changed files with 632 additions and 614 deletions.
4 changes: 3 additions & 1 deletion include/ctranslate2/layers/attention.h
@@ -88,7 +88,8 @@ namespace ctranslate2 {
const RotaryScalingType scaling_type = RotaryScalingType::None,
const float scaling_factor = 1,
const float base = 10000,
const dim_t num_initial_positions = 2048);
const dim_t num_initial_positions = 2048,
const bool transpose = true);

void apply(StorageView& x, const dim_t offset = 0);

@@ -105,6 +106,7 @@ namespace ctranslate2 {
const float _base;
const dim_t _num_initial_positions;
const ops::Rotary _rotary_op;
const bool _transpose;

StorageView _sin;
StorageView _cos;
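The new `transpose` flag tells `RotaryEmbeddings` which memory layout `apply()` will receive: the default `true` keeps the existing behaviour for tensors shaped [batch, heads, time, depth], while the flash-attention path passes `false` because its tensors stay in [batch, time, heads, depth]. The following standalone sketch (not part of the commit; shapes and names are illustrative) shows which axis is then treated as the time dimension, matching the `x.dim(-2)` / `x.dim(-3)` selection made in `attention.cc` below:

#include <cassert>
#include <cstdint>
#include <vector>

using dim_t = int64_t;

// transpose == true  : input is [batch, heads, time, depth] -> time is dim(-2)
// transpose == false : input is [batch, time, heads, depth] -> time is dim(-3)
dim_t rotary_time_dim(const std::vector<dim_t>& shape, bool transpose) {
  assert(shape.size() >= 3);
  return transpose ? shape[shape.size() - 2] : shape[shape.size() - 3];
}

int main() {
  assert(rotary_time_dim({2, 8, 16, 64}, true) == 16);   // transposed layout
  assert(rotary_time_dim({2, 16, 8, 64}, false) == 16);  // flash-attention layout
  return 0;
}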
1,116 changes: 559 additions & 557 deletions include/ctranslate2/layers/flash-attention/flash_fwd_kernel.h

Large diffs are not rendered by default.

@@ -1,3 +1,7 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/

#pragma once

#include "static_switch.h"
6 changes: 4 additions & 2 deletions include/ctranslate2/ops/rotary.h
@@ -12,7 +12,8 @@ namespace ctranslate2 {
void operator()(const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const;
StorageView& output,
bool is_transpose=true) const;

private:
const dim_t _ndims;
@@ -22,7 +23,8 @@ namespace ctranslate2 {
void compute(const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const;
StorageView& output,
bool is_transpose) const;
};

}
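Because the new argument defaults to `true`, existing callers of `ops::Rotary::operator()` are unaffected; only the flash-attention layer passes `false` to signal the [batch, time, heads, depth] layout. A hypothetical call-site sketch (illustrative pseudocode, not taken from this commit; the constructor arguments are assumed from the `_rotary_op(dim, interleave)` initializer in `attention.cc`):

  ops::Rotary rotary(/*ndims=*/0, /*interleave=*/true);
  rotary(q, sin, cos, out);          // default: transposed layout, previous behaviour
  rotary(q, sin, cos, out, false);   // flash-attention layout introduced here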
8 changes: 5 additions & 3 deletions src/layers/attention.cc
@@ -648,21 +648,23 @@ namespace ctranslate2 {
const RotaryScalingType scaling_type,
const float scaling_factor,
const float base,
const dim_t num_initial_positions)
const dim_t num_initial_positions,
const bool transpose)
: _dim(dim)
, _interleave(interleave)
, _scaling_type(scaling_type)
, _scaling_factor(scaling_factor)
, _base(base)
, _num_initial_positions(num_initial_positions)
, _rotary_op(dim, interleave)
, _transpose(transpose)
{
}

void RotaryEmbeddings::apply(StorageView& x, const dim_t offset) {
const Device device = x.device();
const DataType dtype = x.dtype();
const dim_t max_time = x.dim(-2);
const dim_t max_time = _transpose ? x.dim(-2) : x.dim(-3);
const dim_t dim = _dim == 0 ? x.dim(-1) : _dim;

if (!_sin || offset + max_time > _sin.dim(0)) {
@@ -680,7 +682,7 @@ namespace ctranslate2 {
});

StorageView y(dtype, device);
_rotary_op(x, sin, cos, y);
_rotary_op(x, sin, cos, y, _transpose);
x = std::move(y);
}

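For reference, and assuming the cached tables currently hold the default `num_initial_positions = 2048` entries from the header, the check above reuses them as long as `offset + max_time` stays within that length: a decode step at offset 2047 with `max_time = 1` gives 2048, which does not exceed 2048, so nothing is rebuilt, while the next step at offset 2048 gives 2049 and triggers regeneration. The layout flag only changes which axis `max_time` is read from; the growth condition itself is unchanged.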
76 changes: 37 additions & 39 deletions src/layers/flash_attention.cc
@@ -181,12 +181,6 @@ namespace ctranslate2 {
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
}
/*if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}*/
//TENSOR_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
}

@@ -233,14 +227,15 @@ namespace ctranslate2 {
const dim_t time = x.dim(1);
const dim_t head_dim = x.dim(2) / num_heads;

if (time == 1) {
x.reshape({batch_size, time, num_heads, head_dim});
/*if (time == 1) {
x.reshape({batch_size, num_heads, 1, head_dim});
} else {
x.reshape({batch_size, time, num_heads, head_dim});
StorageView y(x.device(), x.dtype());
transpose_op(x, y);
x = std::move(y);
}
}*/
}

static void replicate_heads(StorageView& x, dim_t repeats) {
@@ -255,15 +250,9 @@ namespace ctranslate2 {
dim_t beam_size = 1) {
// x has shape [batch_size, num_heads, time, head_dim]
const dim_t batch_size = x.dim(0);
const dim_t time = x.dim(2);
const dim_t time = x.dim(1);
const dim_t depth = x.dim(3) * num_heads;

if (time > 1) {
StorageView y(x.device(), x.dtype());
transpose_op(x, y);
x = std::move(y);
}

x.reshape({batch_size, time, depth});

if (beam_size > 1)
@@ -291,7 +280,9 @@ namespace ctranslate2 {
interleave,
scaling_type,
scaling_factor,
base);
base,
2048/*num_initial*/,
false/*transpose*/);
}

FlashMultiHeadAttention::FlashMultiHeadAttention(const models::Model& model,
@@ -328,7 +319,7 @@ namespace ctranslate2 {
&& !_relative_attention_bias
&& !_relative_position_keys
&& !_relative_position_values)
, _cache_time_dim(_merge_time_and_head_dims ? 1 : 2)
, _cache_time_dim(1)
, _sliding_window(model.get_attribute_with_default<int32_t>(scope + "/sliding_window", 0))
{
if (_relative_position_keys)
Expand Down Expand Up @@ -430,7 +421,7 @@ namespace ctranslate2 {

} else {
split_heads(fused_proj, 3 * _num_heads, queries_padder);
ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
ops::Split(2)(fused_proj, queries_proj, keys_proj, values_proj);
}

if (_rotary_embeddings) {
@@ -451,9 +442,9 @@ namespace ctranslate2 {
tmp = std::move(*cached_values);
concat_op({&tmp, &values_proj}, *cached_values);

if (!prefilling && _sliding_window > 0 && cached_keys->shape()[2] > _sliding_window) {
if (!prefilling && _sliding_window > 0 && cached_keys->shape()[_cache_time_dim] > _sliding_window) {
// only for generation
const ops::Slide slide_op(2, 1, cached_keys->shape()[2] - 1);
const ops::Slide slide_op(_cache_time_dim, 1, cached_keys->shape()[_cache_time_dim] - 1);
slide_op(*cached_keys, tmp);
*cached_keys = std::move(tmp);
slide_op(*cached_values, tmp);
@@ -466,12 +457,12 @@ namespace ctranslate2 {
keys_proj.shallow_copy(*cached_keys);
values_proj.shallow_copy(*cached_values);
}
StorageView keys_proj_t(dtype, device);
StorageView values_proj_t(dtype, device);
transpose_op(queries_proj, fused_proj);
queries_proj = std::move(fused_proj);
transpose_op(keys_proj, keys_proj_t);
transpose_op(values_proj, values_proj_t);
//StorageView keys_proj_t(dtype, device);
//StorageView values_proj_t(dtype, device);
//transpose_op(queries_proj, fused_proj);
//queries_proj = std::move(fused_proj);
//transpose_op(keys_proj, keys_proj_t);
//transpose_op(values_proj, values_proj_t);

dim_t window_size_right = -1;
dim_t window_size_left = -1;
@@ -505,8 +496,8 @@ namespace ctranslate2 {
dim_t seqlen_q = shape[1];
dim_t num_heads = shape[2];
const dim_t head_size_og = shape[3];
const dim_t seqlen_k = keys_proj_t.dim(1);
const dim_t num_heads_k = keys_proj_t.dim(2);
const dim_t seqlen_k = keys_proj.dim(1);
const dim_t num_heads_k = keys_proj.dim(2);
//TENSOR_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256")
//TENSOR_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

@@ -555,7 +546,7 @@ namespace ctranslate2 {
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
&queries_proj, &keys_proj_t, &values_proj_t, &context,
&queries_proj, &keys_proj, &values_proj, &context,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
@@ -569,30 +560,37 @@ namespace ctranslate2 {
set_params_splitkv(params, batch_size, num_heads,
head_size, seqlen_k, seqlen_q,
head_size_rounded, /*num_splits*/0, &dprops);
StorageView softmax_lse_accum(DataType::FLOAT32, device);
StorageView out_accum(DataType::FLOAT32, device);
if (params.num_splits > 1) {
softmax_lse_accum.resize({params.num_splits, batch_size, num_heads, seqlen_q});
out_accum.resize({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded});
params.softmax_lseaccum_ptr = softmax_lse_accum.buffer();
params.oaccum_ptr = out_accum.buffer();
}
params.alibi_slopes_ptr = nullptr;

std::cout << "queries_proj: " << queries_proj << std::endl;
//std::cout << "keys_proj: " << keys_proj_t << std::endl;
//std::cout << "values_proj: " << values_proj_t << std::endl;
//std::cout << "queries_proj: " << queries_proj << std::endl;
//std::cout << "keys_proj: " << keys_proj << std::endl;
//std::cout << "values_proj: " << values_proj << std::endl;
cudaStream_t stream = ctranslate2::cuda::get_cuda_stream();
run_mha_fwd(params, stream);
std::cout << "softmax_lse: " << softmax_lse << std::endl;
softmax_lse.release();


if (seqlenq_ngroups_swapped) {
transpose_op(context, fused_proj);
context = std::move(fused_proj);
context.reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
//softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}

transpose_op(context, fused_proj);
context = std::move(fused_proj);
//transpose_op(context, fused_proj);
//context = std::move(fused_proj);
//std::cout << "context: " << context << std::endl;

if (prefilling && cached_keys && cached_keys->shape()[2] > _sliding_window) {
if (prefilling && cached_keys && cached_keys->shape()[_cache_time_dim] > _sliding_window) {
// set only last sliding_window tokens to cached_keys and cached_values after computing attention
const ops::Slide slide_op(2, cached_keys->shape()[2] - _sliding_window, _sliding_window);
const ops::Slide slide_op(_cache_time_dim, cached_keys->shape()[_cache_time_dim] - _sliding_window, _sliding_window);
StorageView tmp(dtype, device);
slide_op(*cached_keys, tmp);
*cached_keys = std::move(tmp);
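Besides the layout changes, this file turns the previously commented-out Torch scratch allocation into real `StorageView` buffers: when the split-KV heuristic selects more than one split, two float32 work buffers are created and their raw pointers handed to the kernel parameters. A shape summary (the combine step described here is how FlashAttention's split-KV scheme works in general, not something shown in this diff):

  softmax_lse_accum : [num_splits, batch_size, num_heads, seqlen_q]
  out_accum         : [num_splits, batch_size, num_heads, seqlen_q, head_size_rounded]

Each split writes its partial output and log-sum-exp into these buffers, and a combine pass then reduces over the leading `num_splits` axis. Declaring the two `StorageView`s before the `if` keeps them alive until `run_mha_fwd()` has consumed the pointers stored in `params`.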
5 changes: 3 additions & 2 deletions src/ops/rotary.cc
@@ -14,13 +14,14 @@ namespace ctranslate2 {
void Rotary::operator()(const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const {
StorageView& output,
bool is_transposed) const {
PROFILE("Rotary");

output.resize_as(input);

DEVICE_AND_FLOAT_DISPATCH("Rotary", input.device(), input.dtype(),
(compute<D, T>(input, sin, cos, output)));
(compute<D, T>(input, sin, cos, output, is_transposed)));
}

}
8 changes: 5 additions & 3 deletions src/ops/rotary_cpu.cc
@@ -43,8 +43,9 @@ namespace ctranslate2 {
void Rotary::compute(const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const {
const dim_t max_time = input.dim(-2);
StorageView& output,
bool is_transposed) const {
const dim_t max_time = is_transposed ? input.dim(-2) : input.dim(-3);
const dim_t depth = input.dim(-1);
const dim_t batch_size = input.size() / (max_time * depth);
const dim_t ndims = _ndims == 0 ? depth : _ndims;
@@ -65,7 +66,8 @@ namespace ctranslate2 {
Rotary::compute<Device::CPU, T>(const StorageView&, \
const StorageView&, \
const StorageView&, \
StorageView&) const;
StorageView&, \
bool) const;

DECLARE_IMPL(float)

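The CPU path now reads `max_time` from the layout-dependent axis, but the flattening that follows is layout-agnostic: dividing the total element count by `max_time * depth` collapses the remaining axes, so the `batch_size` used here is really batch × heads for either layout. A standalone sketch with hypothetical dimensions (not taken from the commit):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t batch = 2, time = 16, heads = 8, depth = 64;  // hypothetical sizes

  // [batch, time, heads, depth] (is_transposed == false): time = dim(-3)
  // [batch, heads, time, depth] (is_transposed == true) : time = dim(-2)
  // Either way, the flattened row count is the same:
  const int64_t size = batch * time * heads * depth;
  std::cout << size / (time * depth) << std::endl;  // prints 16 == batch * heads
  return 0;
}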
19 changes: 12 additions & 7 deletions src/ops/rotary_gpu.cu
@@ -30,9 +30,11 @@ namespace ctranslate2 {
const T* cos,
T* y,
const cuda::index_t max_time,
const cuda::index_t head_size,
const cuda::index_t ndims,
const cuda::index_t depth) {
const auto time = blockIdx.x % max_time;
const cuda::index_t depth,
const bool transpose) {
const auto time = transpose ? blockIdx.x % max_time : blockIdx.x / head_size;
const auto middle = ndims / 2;

x += blockIdx.x * depth;
@@ -57,8 +59,10 @@ namespace ctranslate2 {
void Rotary::compute(const StorageView& input,
const StorageView& sin,
const StorageView& cos,
StorageView& output) const {
const dim_t max_time = input.dim(-2);
StorageView& output,
bool is_transposed) const {
const dim_t max_time = is_transposed ? input.dim(-2) : input.dim(-3);
const dim_t head_size = is_transposed ? input.dim(-3) : input.dim(-2);
const dim_t depth = input.dim(-1);
const dim_t ndims = _ndims == 0 ? depth : _ndims;

@@ -74,18 +78,19 @@ namespace ctranslate2 {

if (_interleave)
rotary_kernel<DeviceT, true><<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
x, s, c, y, max_time, ndims, depth);
x, s, c, y, max_time, head_size, ndims, depth, is_transposed);
else
rotary_kernel<DeviceT, false><<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
x, s, c, y, max_time, ndims, depth);
x, s, c, y, max_time, head_size, ndims, depth, is_transposed);
}

#define DECLARE_IMPL(T) \
template void \
Rotary::compute<Device::CUDA, T>(const StorageView&, \
const StorageView&, \
const StorageView&, \
StorageView&) const;
StorageView&, \
bool) const;

DECLARE_IMPL(float)
DECLARE_IMPL(float16_t)
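In the CUDA kernel each block still handles one row of length `depth`; what changes with the layout is how the block index maps back to a time step for the sin/cos lookup. Note that, despite its name, the new `head_size` argument is the size of the heads axis (`dim(-3)` or `dim(-2)` depending on the layout), i.e. the number of heads. A small host-side sketch mirroring the index arithmetic above (illustrative only, not part of the commit):

#include <cassert>
#include <cstdint>

using index_t = uint32_t;

// Reproduces the `time` value computed by rotary_kernel for a given block index.
// transposed == true : rows are flattened as [batch * heads, time]
// transposed == false: rows are flattened as [batch * time, heads]
index_t rotary_time_index(index_t block_idx, index_t max_time,
                          index_t num_heads, bool transposed) {
  return transposed ? block_idx % max_time : block_idx / num_heads;
}

int main() {
  // Hypothetical single-batch case: heads = 8, time = 16.
  assert(rotary_time_index(17, /*max_time=*/16, /*num_heads=*/8, true) == 1);   // 17 % 16
  assert(rotary_time_index(17, /*max_time=*/16, /*num_heads=*/8, false) == 2);  // 17 / 8
  return 0;
}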
