fix: implement llama3 RoPE scaling type and fix converter (#1751)
* fix: implement llama3 RoPE scaling type and fix converter

* build: add definition for M_PI on windows
ebraraktas authored Aug 12, 2024
1 parent e6a8f94 commit a386cbd
Showing 5 changed files with 59 additions and 2 deletions.
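For context, the "llama3" scaling type added by this commit rescales each RoPE inverse frequency according to its wavelength: wavelengths shorter than old_context_len / high_freq_factor are left untouched, wavelengths longer than old_context_len / low_freq_factor are divided by the scaling factor, and the band in between is linearly interpolated between the two. A minimal Python sketch of the rule implemented in src/layers/attention_layer.cc below (the function name is illustrative, and the defaults mirror the Llama-3.1 release config):

import math

def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                          high_freq_factor=4.0, old_context_len=8192):
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    scaled = []
    for freq in inv_freq:
        wavelen = 2.0 * math.pi / freq
        if wavelen < high_freq_wavelen:
            scaled.append(freq)  # high-frequency band: unchanged
        elif wavelen > low_freq_wavelen:
            scaled.append(freq / factor)  # low-frequency band: fully scaled
        else:
            # Mid band: interpolate between the scaled and unscaled value.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return scaled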
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -80,6 +80,7 @@ string(REPLACE "." ";" CTRANSLATE2_VERSION_LIST ${CTRANSLATE2_VERSION})
list(GET CTRANSLATE2_VERSION_LIST 0 CTRANSLATE2_MAJOR_VERSION)

if(MSVC)
  add_compile_definitions(_USE_MATH_DEFINES) # required for M_PI
  if(BUILD_SHARED_LIBS)
    set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
  else()
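Background on this one-line change: MSVC's <cmath> and <math.h> only define M_PI (and the other M_* constants) when _USE_MATH_DEFINES is defined before the header is first included, so defining it globally for the whole build avoids include-order surprises with the new M_PI usage in src/layers/attention_layer.cc.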
5 changes: 5 additions & 0 deletions include/ctranslate2/layers/attention_layer.h
@@ -73,6 +73,7 @@ namespace ctranslate2 {
      None = -1,
      Linear,
      Su,
      Llama3,
    };

    class RotaryEmbeddings {
@@ -85,6 +86,8 @@ namespace ctranslate2 {
                       const dim_t num_initial_positions = 2048,
                       const StorageView* long_scaling_factor = nullptr,
                       const StorageView* short_scaling_factor = nullptr,
                       const float low_freq_factor = 1.0,
                       const float high_freq_factor = 4.0,
                       const dim_t original_max_position_embeddings = 0,
                       const dim_t max_position_embeddings = 0,
                       const bool transpose = true);
@@ -117,6 +120,8 @@ namespace ctranslate2 {
      const dim_t _num_initial_positions;
      std::unique_ptr<StorageView> _rotary_scaling_long_factor;
      std::unique_ptr<StorageView> _rotary_scaling_short_factor;
      const float _rotary_low_freq_factor;
      const float _rotary_high_freq_factor;
      const dim_t _original_max_position_embeddings;
      const dim_t _max_position_embeddings;
      const ops::Rotary _rotary_op;
15 changes: 14 additions & 1 deletion python/ctranslate2/converters/transformers.py
@@ -41,6 +41,7 @@
_SUPPORTED_ROPE_SCALING = {
    "linear": attention_spec.RotaryScalingType.Linear,
    "su": attention_spec.RotaryScalingType.Su,
    "llama3": attention_spec.RotaryScalingType.Llama3,
}

_SUPPORTED_QUANTIZATION = {
@@ -1405,7 +1406,8 @@ def get_model_spec(self, model):

        rope_scaling = getattr(model.config, "rope_scaling", None)
        if rope_scaling:
            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
            rotary_scaling_factor = rope_scaling["factor"]

            if rotary_scaling_type is None:
@@ -1420,6 +1422,7 @@

        quantization_config = getattr(model.config, "quantization_config", None)
        if quantization_config:
            quant_type = None
            if quantization_config.quant_method == "awq":
                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
                if quant_type is None:
@@ -1458,6 +1461,16 @@ def get_model_spec(self, model):

        self.set_decoder(spec.decoder, model.model, quant_type)
        self.set_linear(spec.decoder.projection, model.lm_head)

        # Set extra RoPE parameters for Llama-3.1.
        if rotary_scaling_type == attention_spec.RotaryScalingType.Llama3:
            for layer in spec.decoder.layer:
                layer.self_attention.rotary_low_freq_factor = rope_scaling[
                    "low_freq_factor"
                ]
                layer.self_attention.rotary_high_freq_factor = rope_scaling[
                    "high_freq_factor"
                ]
        return spec

    def get_vocabulary(self, model, tokenizer):
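The converter fix above handles a key rename in Hugging Face configs: older checkpoints store the scaling type under "type", while Llama-3.1 checkpoints use "rope_type", so the converter now reads whichever is present. For illustration, the rope_scaling section of a Llama-3.1 config.json looks like this (values taken from the 3.1 release configs; treat the snippet as an example):

rope_scaling = {
    "rope_type": "llama3",  # newer key; older configs used "type"
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}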
4 changes: 4 additions & 0 deletions python/ctranslate2/specs/attention_spec.py
@@ -11,6 +11,7 @@ class RotaryScalingType(enum.IntEnum):

    Linear = 0
    Su = 1
    Llama3 = 2


class MultiHeadAttentionSpec(model_spec.LayerSpec):
@@ -69,6 +70,9 @@ def __init__(
        elif rotary_scaling_type is RotaryScalingType.Su:
            self.rotary_scaling_long_factor = None
            self.rotary_scaling_short_factor = None
        elif rotary_scaling_type is RotaryScalingType.Llama3:
            self.rotary_low_freq_factor = None
            self.rotary_high_freq_factor = None

        if num_heads_kv is not None:
            self.num_heads_kv = np.dtype("int32").type(num_heads_kv)
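These spec attributes start as None and are filled in by the converter change above; when a converted model is loaded, the C++ side reads them back as the attributes rotary_low_freq_factor and rotary_high_freq_factor under the attention scope, falling back to 1.0 and 4.0 for models converted before this change (see the attention_layer.cc hunks below).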
36 changes: 35 additions & 1 deletion src/layers/attention_layer.cc
@@ -90,6 +90,10 @@ namespace ctranslate2 {
      const auto max_position_embeddings = model.get_attribute_with_default<int32_t>(
        scope + "/max_position_embeddings", 0);

      const auto rotary_high_freq_factor = model.get_attribute_with_default<float>(
        scope + "/rotary_high_freq_factor", 4.0);
      const auto rotary_low_freq_factor = model.get_attribute_with_default<float>(
        scope + "/rotary_low_freq_factor", 1.0);
      return std::make_unique<RotaryEmbeddings>(rotary_dim,
                                                interleave,
                                                scaling_type,
@@ -98,6 +102,8 @@ namespace ctranslate2 {
                                                /*num_initial_positions*/2048,
                                                rotary_long_factor,
                                                rotary_short_factor,
                                                rotary_low_freq_factor,
                                                rotary_high_freq_factor,
                                                original_max_position_embeddings,
                                                max_position_embeddings,
                                                transpose);
@@ -177,6 +183,8 @@ namespace ctranslate2 {
                                       const dim_t num_initial_positions,
                                       const StorageView* long_scaling_factor,
                                       const StorageView* short_scaling_factor,
                                       const float low_freq_factor,
                                       const float high_freq_factor,
                                       const dim_t original_max_position_embeddings,
                                       const dim_t max_position_embeddings,
                                       const bool transpose)
@@ -190,6 +198,8 @@ namespace ctranslate2 {
          std::make_unique<StorageView>(*long_scaling_factor) : nullptr)
      , _rotary_scaling_short_factor(short_scaling_factor ?
          std::make_unique<StorageView>(*short_scaling_factor) : nullptr)
      , _rotary_low_freq_factor(low_freq_factor)
      , _rotary_high_freq_factor(high_freq_factor)
      , _original_max_position_embeddings(original_max_position_embeddings)
      , _max_position_embeddings(max_position_embeddings)
      , _rotary_op(dim, interleave)
@@ -259,6 +269,30 @@ namespace ctranslate2 {
      else {
        for (dim_t i = 0; i < inv_freq.size(); ++i)
          inv_freq.at<float>(i) = 1.f / std::pow(_base, float(i * 2) / float(dim));
        if (_scaling_type == RotaryScalingType::Llama3) {
          StorageView new_freqs = inv_freq.sync_copy();

          const auto factor = _scaling_factor;
          const float low_freq_factor = _rotary_low_freq_factor;
          const float high_freq_factor = _rotary_high_freq_factor;
          const auto old_context_len = static_cast<float>(_original_max_position_embeddings);

          float low_freq_wavelen = old_context_len / low_freq_factor;
          float high_freq_wavelen = old_context_len / high_freq_factor;
          for (dim_t i = 0; i < inv_freq.size(); ++i) {
            float wavelen = 2.0f * M_PI / inv_freq.at<float>(i);
            if (wavelen < high_freq_wavelen) {
              // Do nothing: the value was already copied from inv_freq.
            } else if (wavelen > low_freq_wavelen) {
              new_freqs.at<float>(i) /= factor;
            } else {
              float smooth = (old_context_len / wavelen - low_freq_factor)
                             / (high_freq_factor - low_freq_factor);
              auto freq = inv_freq.at<float>(i);
              new_freqs.at<float>(i) = (1 - smooth) * freq / factor + smooth * freq;
            }
          }
          inv_freq = std::move(new_freqs);
        }
      }
      if (inv_freq.device() != device)
        inv_freq = inv_freq.to(device);
@@ -296,7 +330,7 @@ namespace ctranslate2 {
      else
        _cos = cos.to(dtype);

      if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0) {
      if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0 && _scaling_type != RotaryScalingType::Llama3) {
        StorageView scaling_factor;
        float scale = _max_position_embeddings / _original_max_position_embeddings;
        if (scale <= 1)
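Two details in this last hunk are easy to miss. First, the piecewise rule is continuous at both band edges: at wavelen == low_freq_wavelen the smooth term is 0 and the result equals freq / factor, and at wavelen == high_freq_wavelen it is 1 and the result equals freq. Second, the modified condition at the end excludes Llama3 so that the Su-style long/short scaling path is not also applied to Llama-3.1 models, which can carry original_max_position_embeddings as well. A quick continuity check against the Python sketch near the top of this page (values are illustrative):

import math

# Frequencies sitting exactly on the two band edges of the sketch above.
edges = [2 * math.pi / (8192 / 1.0), 2 * math.pi / (8192 / 4.0)]
scaled = llama3_scale_inv_freq(edges)
assert math.isclose(scaled[0], edges[0] / 8.0)  # low edge: fully scaled
assert math.isclose(scaled[1], edges[1])        # high edge: unchanged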
