@@ -90,6 +90,10 @@ namespace ctranslate2 {
90
90
const auto max_position_embeddings = model.get_attribute_with_default <int32_t >(
91
91
scope + " /max_position_embeddings" , 0 );
92
92
93
+ const auto rotary_high_freq_factor = model.get_attribute_with_default <float >(scope +
94
+ " /rotary_high_freq_factor" , 4.0 );
95
+ const auto rotary_low_freq_factor = model.get_attribute_with_default <float >(scope +
96
+ " /rotary_low_freq_factor" , 1.0 );
93
97
return std::make_unique<RotaryEmbeddings>(rotary_dim,
94
98
interleave,
95
99
scaling_type,
@@ -98,6 +102,8 @@ namespace ctranslate2 {
98
102
/* num_initial_positions*/ 2048 ,
99
103
rotary_long_factor,
100
104
rotary_short_factor,
105
+ rotary_low_freq_factor,
106
+ rotary_high_freq_factor,
101
107
original_max_position_embeddings,
102
108
max_position_embeddings,
103
109
transpose);
@@ -177,6 +183,8 @@ namespace ctranslate2 {
177
183
const dim_t num_initial_positions,
178
184
const StorageView* long_scaling_factor,
179
185
const StorageView* short_scaling_factor,
186
+ const float low_freq_factor,
187
+ const float high_freq_factor,
180
188
const dim_t original_max_position_embeddings,
181
189
const dim_t max_position_embeddings,
182
190
const bool transpose)
@@ -190,6 +198,8 @@ namespace ctranslate2 {
190
198
std::make_unique<StorageView>(*long_scaling_factor) : nullptr)
191
199
, _rotary_scaling_short_factor(short_scaling_factor ?
192
200
std::make_unique<StorageView>(*short_scaling_factor) : nullptr)
201
+ , _rotary_low_freq_factor(low_freq_factor)
202
+ , _rotary_high_freq_factor(high_freq_factor)
193
203
, _original_max_position_embeddings(original_max_position_embeddings)
194
204
, _max_position_embeddings(max_position_embeddings)
195
205
, _rotary_op(dim, interleave)
@@ -259,6 +269,30 @@ namespace ctranslate2 {
259
269
else {
260
270
for (dim_t i = 0 ; i < inv_freq.size (); ++i)
261
271
inv_freq.at <float >(i) = 1 .f / std::pow (_base, float (i * 2 ) / float (dim));
272
+ if (_scaling_type == RotaryScalingType::Llama3) {
273
+ StorageView new_freqs = inv_freq.sync_copy ();
274
+
275
+ const auto factor = _scaling_factor;
276
+ const float low_freq_factor = _rotary_low_freq_factor;
277
+ const float high_freq_factor = _rotary_high_freq_factor;
278
+ const auto old_context_len = static_cast < float >(_original_max_position_embeddings);
279
+
280
+ float low_freq_wavelen = old_context_len / low_freq_factor;
281
+ float high_freq_wavelen = old_context_len / high_freq_factor;
282
+ for (dim_t i = 0 ; i < inv_freq.size (); ++i) {
283
+ float wavelen = 2 .0f * M_PI / inv_freq.at <float >(i);
284
+ if (wavelen < high_freq_wavelen) {
285
+ // do nothing as we copied from inv_freq already.
286
+ } else if (wavelen > low_freq_wavelen) {
287
+ new_freqs.at <float >(i) /= factor;
288
+ } else {
289
+ float smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
290
+ auto freq = inv_freq.at <float >(i);
291
+ new_freqs.at <float >(i) = ((1 - smooth) * freq / factor + smooth * freq);
292
+ }
293
+ }
294
+ inv_freq = std::move (new_freqs);
295
+ }
262
296
}
263
297
if (inv_freq.device () != device)
264
298
inv_freq = inv_freq.to (device);
@@ -296,7 +330,7 @@ namespace ctranslate2 {
296
330
else
297
331
_cos = cos .to (dtype);
298
332
299
- if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0 ) {
333
+ if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0 && _scaling_type != RotaryScalingType::Llama3 ) {
300
334
StorageView scaling_factor;
301
335
float scale = _max_position_embeddings / _original_max_position_embeddings;
302
336
if (scale <= 1 )
0 commit comments