Commit 8429e5e

fix: implement llama3 RoPE scaling type and fix converter
1 parent e6a8f94 commit 8429e5e

File tree

include/ctranslate2/layers/attention_layer.h
python/ctranslate2/converters/transformers.py
python/ctranslate2/specs/attention_spec.py
src/layers/attention_layer.cc

4 files changed: +58 -2 lines changed

Diff for: include/ctranslate2/layers/attention_layer.h (+5)

@@ -73,6 +73,7 @@ namespace ctranslate2 {
       None = -1,
       Linear,
       Su,
+      Llama3,
     };

     class RotaryEmbeddings {
@@ -85,6 +86,8 @@ namespace ctranslate2 {
                        const dim_t num_initial_positions = 2048,
                        const StorageView* long_scaling_factor = nullptr,
                        const StorageView* short_scaling_factor = nullptr,
+                       const float low_freq_factor = 1.0,
+                       const float high_freq_factor = 4.0,
                        const dim_t original_max_position_embeddings = 0,
                        const dim_t max_position_embeddings = 0,
                        const bool transpose = true);
@@ -117,6 +120,8 @@ namespace ctranslate2 {
       const dim_t _num_initial_positions;
       std::unique_ptr<StorageView> _rotary_scaling_long_factor;
       std::unique_ptr<StorageView> _rotary_scaling_short_factor;
+      const float _rotary_low_freq_factor;
+      const float _rotary_high_freq_factor;
       const dim_t _original_max_position_embeddings;
       const dim_t _max_position_embeddings;
       const ops::Rotary _rotary_op;

Diff for: python/ctranslate2/converters/transformers.py (+14 -1)

@@ -41,6 +41,7 @@
 _SUPPORTED_ROPE_SCALING = {
     "linear": attention_spec.RotaryScalingType.Linear,
     "su": attention_spec.RotaryScalingType.Su,
+    "llama3": attention_spec.RotaryScalingType.Llama3,
 }

 _SUPPORTED_QUANTIZATION = {
@@ -1405,7 +1406,8 @@ def get_model_spec(self, model):

         rope_scaling = getattr(model.config, "rope_scaling", None)
         if rope_scaling:
-            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
+            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
+            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
             rotary_scaling_factor = rope_scaling["factor"]

             if rotary_scaling_type is None:
@@ -1420,6 +1422,7 @@ def get_model_spec(self, model):

         quantization_config = getattr(model.config, "quantization_config", None)
         if quantization_config:
+            quant_type = None
             if quantization_config.quant_method == "awq":
                 quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
                 if quant_type is None:
@@ -1458,6 +1461,16 @@ def get_model_spec(self, model):

         self.set_decoder(spec.decoder, model.model, quant_type)
         self.set_linear(spec.decoder.projection, model.lm_head)
+
+        # set extra RoPE parameters for Llama-3.1
+        if rotary_scaling_type == attention_spec.RotaryScalingType.Llama3:
+            for layer in spec.decoder.layer:
+                layer.self_attention.rotary_low_freq_factor = rope_scaling[
+                    "low_freq_factor"
+                ]
+                layer.self_attention.rotary_high_freq_factor = rope_scaling[
+                    "high_freq_factor"
+                ]
         return spec

     def get_vocabulary(self, model, tokenizer):
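
Note on the converter fix: for Llama-3.1 checkpoints, recent transformers versions store the scaling type under the "rope_type" key instead of the older "type" key, which is why the lookup now falls back to rope_scaling["rope_type"]. For illustration only (the exact values vary per checkpoint; these are the ones commonly published for Llama-3.1), the rope_scaling entry in config.json looks roughly like:

# Illustrative rope_scaling dict from a Llama-3.1 style config.json.
rope_scaling = {
    "rope_type": "llama3",  # newer key; older configs used "type"
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

# The fixed lookup accepts either key:
rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
assert rope_type == "llama3"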

Diff for: python/ctranslate2/specs/attention_spec.py (+4)

@@ -11,6 +11,7 @@ class RotaryScalingType(enum.IntEnum):

     Linear = 0
     Su = 1
+    Llama3 = 2


 class MultiHeadAttentionSpec(model_spec.LayerSpec):
@@ -69,6 +70,9 @@ def __init__(
         elif rotary_scaling_type is RotaryScalingType.Su:
             self.rotary_scaling_long_factor = None
             self.rotary_scaling_short_factor = None
+        elif rotary_scaling_type is RotaryScalingType.Llama3:
+            self.rotary_low_freq_factor = None
+            self.rotary_high_freq_factor = None

         if num_heads_kv is not None:
             self.num_heads_kv = np.dtype("int32").type(num_heads_kv)
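
The None placeholders declare the new optional spec attributes; the converter (previous diff) overwrites them with the values from config.json, and they are serialized as model attributes that the C++ runtime reads back (next diff). A minimal sketch, assuming MultiHeadAttentionSpec can be constructed with just the rotary_scaling_type keyword its __init__ branches on:

from ctranslate2.specs import attention_spec

# Hypothetical construction for illustration; real instantiation happens
# inside the converter's spec-building code.
spec = attention_spec.MultiHeadAttentionSpec(
    rotary_scaling_type=attention_spec.RotaryScalingType.Llama3,
)
spec.rotary_low_freq_factor = 1.0   # from rope_scaling["low_freq_factor"]
spec.rotary_high_freq_factor = 4.0  # from rope_scaling["high_freq_factor"]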

Diff for: src/layers/attention_layer.cc (+35 -1)

@@ -90,6 +90,10 @@ namespace ctranslate2 {
     const auto max_position_embeddings = model.get_attribute_with_default<int32_t>(
       scope + "/max_position_embeddings", 0);

+    const auto rotary_high_freq_factor = model.get_attribute_with_default<float>(scope +
+                                                                 "/rotary_high_freq_factor", 4.0);
+    const auto rotary_low_freq_factor = model.get_attribute_with_default<float>(scope +
+                                                                 "/rotary_low_freq_factor", 1.0);
     return std::make_unique<RotaryEmbeddings>(rotary_dim,
                                               interleave,
                                               scaling_type,
@@ -98,6 +102,8 @@ namespace ctranslate2 {
                                               /*num_initial_positions*/2048,
                                               rotary_long_factor,
                                               rotary_short_factor,
+                                              rotary_low_freq_factor,
+                                              rotary_high_freq_factor,
                                               original_max_position_embeddings,
                                               max_position_embeddings,
                                               transpose);
@@ -177,6 +183,8 @@ namespace ctranslate2 {
                                      const dim_t num_initial_positions,
                                      const StorageView* long_scaling_factor,
                                      const StorageView* short_scaling_factor,
+                                     const float low_freq_factor,
+                                     const float high_freq_factor,
                                      const dim_t original_max_position_embeddings,
                                      const dim_t max_position_embeddings,
                                      const bool transpose)
@@ -190,6 +198,8 @@ namespace ctranslate2 {
                                     std::make_unique<StorageView>(*long_scaling_factor) : nullptr)
     , _rotary_scaling_short_factor(short_scaling_factor ?
                                    std::make_unique<StorageView>(*short_scaling_factor) : nullptr)
+    , _rotary_low_freq_factor(low_freq_factor)
+    , _rotary_high_freq_factor(high_freq_factor)
     , _original_max_position_embeddings(original_max_position_embeddings)
     , _max_position_embeddings(max_position_embeddings)
     , _rotary_op(dim, interleave)
@@ -259,6 +269,30 @@ namespace ctranslate2 {
       else {
         for (dim_t i = 0; i < inv_freq.size(); ++i)
           inv_freq.at<float>(i) = 1.f / std::pow(_base, float(i * 2) / float(dim));
+        if (_scaling_type == RotaryScalingType::Llama3) {
+          StorageView new_freqs = inv_freq.sync_copy();
+
+          const auto factor = _scaling_factor;
+          const float low_freq_factor = _rotary_low_freq_factor;
+          const float high_freq_factor = _rotary_high_freq_factor;
+          const auto old_context_len = static_cast<float>(_original_max_position_embeddings);
+
+          float low_freq_wavelen = old_context_len / low_freq_factor;
+          float high_freq_wavelen = old_context_len / high_freq_factor;
+          for (dim_t i = 0; i < inv_freq.size(); ++i) {
+            float wavelen = 2.0f * M_PI / inv_freq.at<float>(i);
+            if (wavelen < high_freq_wavelen) {
+              // do nothing as we copied from inv_freq already.
+            } else if (wavelen > low_freq_wavelen) {
+              new_freqs.at<float>(i) /= factor;
+            } else {
+              float smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
+              auto freq = inv_freq.at<float>(i);
+              new_freqs.at<float>(i) = ((1 - smooth) * freq / factor + smooth * freq);
+            }
+          }
+          inv_freq = std::move(new_freqs);
+        }
       }
       if (inv_freq.device() != device)
         inv_freq = inv_freq.to(device);
@@ -296,7 +330,7 @@ namespace ctranslate2 {
       else
         _cos = cos.to(dtype);

-      if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0) {
+      if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0 && _scaling_type != RotaryScalingType::Llama3) {
        StorageView scaling_factor;
        float scale = _max_position_embeddings / _original_max_position_embeddings;
        if (scale <= 1)
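
The rescaling block follows the Llama-3.1 recipe: inverse frequencies whose wavelength is below old_context_len / high_freq_factor are kept unchanged, wavelengths above old_context_len / low_freq_factor are divided by the scaling factor, and the band in between is linearly interpolated. The last hunk then excludes Llama3 from the Su-style sin/cos attention scaling, since Llama3 applies its scaling directly to inv_freq. A NumPy sketch of the same loop, for illustration (names mirror the C++ above; not the shipped code):

import numpy as np

def llama3_rescale_inv_freq(inv_freq, factor, low_freq_factor,
                            high_freq_factor, old_context_len):
    # Port of the C++ loop above: rescale RoPE inverse frequencies
    # for Llama-3.1 style context extension.
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = inv_freq.copy()
    for i, freq in enumerate(inv_freq):
        wavelen = 2.0 * np.pi / freq
        if wavelen < high_freq_wavelen:
            pass  # high-frequency band: keep as-is
        elif wavelen > low_freq_wavelen:
            new_freqs[i] = freq / factor  # low-frequency band: fully scaled
        else:
            # mid band: linear interpolation between scaled and unscaled
            smooth = (old_context_len / wavelen - low_freq_factor) \
                     / (high_freq_factor - low_freq_factor)
            new_freqs[i] = (1 - smooth) * freq / factor + smooth * freq
    return new_freqs

# Example with the defaults from this commit (8192 is the published
# Llama-3.1 original context length, assumed here for illustration):
dim, base = 128, 10000.0
inv_freq = 1.0 / base ** (np.arange(0, dim, 2) / dim)
scaled = llama3_rescale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                                 high_freq_factor=4.0, old_context_len=8192)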
