diff --git a/include/ctranslate2/layers/attention_layer.h b/include/ctranslate2/layers/attention_layer.h
index daa9206c5..bea2e1e42 100644
--- a/include/ctranslate2/layers/attention_layer.h
+++ b/include/ctranslate2/layers/attention_layer.h
@@ -72,6 +72,7 @@ namespace ctranslate2 {
     enum class RotaryScalingType {
       None = -1,
       Linear,
+      Su,
     };
 
     class RotaryEmbeddings {
@@ -82,6 +83,10 @@ namespace ctranslate2 {
                        const float scaling_factor = 1,
                        const float base = 10000,
                        const dim_t num_initial_positions = 2048,
+                       const StorageView* long_scaling_factor = nullptr,
+                       const StorageView* short_scaling_factor = nullptr,
+                       const dim_t original_max_position_embeddings = 0,
+                       const dim_t max_position_embeddings = 0,
                        const bool transpose = true);
 
       void apply(StorageView& x, const dim_t offset = 0, bool apply = true);
@@ -110,6 +115,10 @@ namespace ctranslate2 {
       const float _scaling_factor;
       const float _base;
       const dim_t _num_initial_positions;
+      std::unique_ptr<StorageView> _rotary_scaling_long_factor;
+      std::unique_ptr<StorageView> _rotary_scaling_short_factor;
+      const dim_t _original_max_position_embeddings;
+      const dim_t _max_position_embeddings;
       const ops::Rotary _rotary_op;
       const bool _transpose;
 
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 84bf18a71..719983a3d 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -40,6 +40,7 @@
 
 _SUPPORTED_ROPE_SCALING = {
     "linear": attention_spec.RotaryScalingType.Linear,
+    "su": attention_spec.RotaryScalingType.Su,
 }
 
 _MODEL_LOADERS = {}
@@ -346,9 +347,11 @@ def set_common_layers(self, spec, module):
         spec.scale_embeddings = module.embed_scale
         self.set_position_encodings(spec.position_encodings, module.embed_positions)
         self.set_embeddings(
-            spec.embeddings[0]
-            if isinstance(spec.embeddings, list)
-            else spec.embeddings,
+            (
+                spec.embeddings[0]
+                if isinstance(spec.embeddings, list)
+                else spec.embeddings
+            ),
             module.embed_tokens,
         )
 
@@ -1066,9 +1069,11 @@ def set_config(self, config, model, tokenizer):
     def set_stack(self, spec, module, is_decoder=False):
         self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
         self.set_embeddings(
-            spec.embeddings[0]
-            if isinstance(spec.embeddings, list)
-            else spec.embeddings,
+            (
+                spec.embeddings[0]
+                if isinstance(spec.embeddings, list)
+                else spec.embeddings
+            ),
             module.embed_tokens,
         )
 
@@ -1298,9 +1303,11 @@ def get_model_spec(self, model):
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
-            activation=common_spec.Activation.GELU
-            if activation_config == "gelu"
-            else common_spec.Activation.GELUTanh,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
             pre_norm=True,
             ffn_glu=True,
             rms_norm=True,
@@ -1694,10 +1701,14 @@ def get_model_spec(self, model):
         if num_heads_kv == num_heads:
             num_heads_kv = None
 
+        original_max_position_embeddings = getattr(
+            model.config, "original_max_position_embeddings", 0
+        )
+        max_position_embeddings = getattr(model.config, "max_position_embeddings", 0)
         rope_scaling = getattr(model.config, "rope_scaling", None)
         if rope_scaling:
             rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
-            rotary_scaling_factor = rope_scaling["factor"]
+            rotary_scaling_factor = rope_scaling.get("factor", 1)
 
             if rotary_scaling_type is None:
                 raise NotImplementedError(
@@ -1721,6 +1732,8 @@ def get_model_spec(self, model):
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=getattr(model.config, "rope_theta", 10000),
+            original_max_position_embeddings=original_max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             num_heads_kv=num_heads_kv,
         )
 
@@ -1748,6 +1761,16 @@ def set_config(self, config, model, tokenizer):
     def set_layer_norm(self, spec, layer_norm):
         spec.gamma = layer_norm.weight
 
+    def set_rotary_embeddings(
+        self, spec, rotary_scaling_long_factor, rotary_scaling_short_factor
+    ):
+        spec.rotary_scaling_long_factor = torch.tensor(
+            rotary_scaling_long_factor, dtype=torch.float32
+        )
+        spec.rotary_scaling_short_factor = torch.tensor(
+            rotary_scaling_short_factor, dtype=torch.float32
+        )
+
     def set_decoder(self, spec, module):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
@@ -1765,6 +1788,15 @@ def set_decoder(self, spec, module):
                 layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
             )
             self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)
+            if (
+                getattr(layer.self_attn.rotary_emb, "long_factor", None) is not None
+                and getattr(layer.self_attn.rotary_emb, "short_factor", None) is not None
+            ):
+                self.set_rotary_embeddings(
+                    layer_spec.self_attention,
+                    layer.self_attn.rotary_emb.long_factor,
+                    layer.self_attn.rotary_emb.short_factor,
+                )
 
             gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
             layer_spec.ffn.linear_0.weight = gate_proj
diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py
index 14ef17d07..88eabb7af 100644
--- a/python/ctranslate2/specs/attention_spec.py
+++ b/python/ctranslate2/specs/attention_spec.py
@@ -10,6 +10,7 @@ class RotaryScalingType(enum.IntEnum):
     """RoPE scaling type."""
 
     Linear = 0
+    Su = 1
 
 
 class MultiHeadAttentionSpec(model_spec.LayerSpec):
@@ -24,6 +25,8 @@ def __init__(
         rotary_scaling_type=None,
         rotary_scaling_factor=1,
         rotary_base=10000,
+        original_max_position_embeddings=0,
+        max_position_embeddings=0,
         num_heads_kv=None,
         head_dim=None,
         sliding_window=None,
@@ -43,6 +46,15 @@ def __init__(
         self.relative_attention_bias = None
         self.relative_attention_max_distance = None
 
+        if original_max_position_embeddings != 0:
+            self.original_max_position_embeddings = np.dtype("int32").type(
+                original_max_position_embeddings
+            )
+        if max_position_embeddings != 0:
+            self.max_position_embeddings = np.dtype("int32").type(
+                max_position_embeddings
+            )
+
         if rotary_dim is not None:
             self.rotary_dim = np.dtype("int32").type(rotary_dim)
             self.rotary_interleave = rotary_interleave
@@ -50,9 +62,13 @@ def __init__(
 
         if rotary_scaling_type is not None:
             self.rotary_scaling_type = np.dtype("int8").type(rotary_scaling_type)
+        if rotary_scaling_type is RotaryScalingType.Linear:
             self.rotary_scaling_factor = np.dtype("float32").type(
                 rotary_scaling_factor
             )
+        elif rotary_scaling_type is RotaryScalingType.Su:
+            self.rotary_scaling_long_factor = None
+            self.rotary_scaling_short_factor = None
 
         if num_heads_kv is not None:
             self.num_heads_kv = np.dtype("int32").type(num_heads_kv)
diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py
index c1e08f7c8..41710ff41 100644
--- a/python/ctranslate2/specs/model_spec.py
+++ b/python/ctranslate2/specs/model_spec.py
@@ -35,6 +35,8 @@
     "float32",
 )
 
+SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor")
+
 
 def _join_scope(scope, name):
     if not scope:
@@ -175,9 +177,13 @@ def _alias_variables(self):
                     break
                 # Because variables can be transformed on load (e.g. transposed),
                 # we use an element-wise equality check.
-                if not value.is_scalar() and value.equal(other_value):
+                scope, attr_name = _parent_scope(name)
+                if (
+                    not value.is_scalar()
+                    and value.equal(other_value)
+                    and attr_name not in SKIP_CREATING_ALIAS
+                ):
                     # Replace variable value by the alias name.
-                    scope, attr_name = _parent_scope(name)
                     spec = index_spec(self, scope)
                     setattr(spec, attr_name, other_name)
                     break
diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py
index c3f8d91be..2325f7bbf 100644
--- a/python/ctranslate2/specs/transformer_spec.py
+++ b/python/ctranslate2/specs/transformer_spec.py
@@ -97,6 +97,8 @@ def __init__(
         rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
         rotary_scaling_factor: float = 1,
         rotary_base: float = 10000,
+        original_max_position_embeddings: int = 0,
+        max_position_embeddings: int = 0,
         parallel_residual: bool = False,
         shared_layer_norm: bool = False,
         multi_query_attention: bool = False,
@@ -135,6 +137,9 @@ def __init__(
             rotary_scaling_type: Type of RoPE scaling.
             rotary_scaling_factor: Factor used in the RoPE scaling.
             rotary_base: The base period of the rotary embeddings.
+            original_max_position_embeddings: The original max position embeddings
+                for Su RoPE scaling.
+            max_position_embeddings: The max position embeddings for Su RoPE scaling.
             parallel_residual: Use parallel residual connections in each layer block,
                 as used by the GPT-J and GPT-NeoX models.
             shared_layer_norm: When using parallel residual, share the input and post
@@ -199,6 +204,8 @@ def __init__(
                 rotary_scaling_type=rotary_scaling_type,
                 rotary_scaling_factor=rotary_scaling_factor,
                 rotary_base=rotary_base,
+                original_max_position_embeddings=original_max_position_embeddings,
+                max_position_embeddings=max_position_embeddings,
                 parallel_residual=parallel_residual,
                 shared_layer_norm=shared_layer_norm,
                 num_heads_kv=num_heads_kv,
@@ -251,6 +258,8 @@ def __init__(
         rotary_scaling_type=None,
         rotary_scaling_factor=1,
         rotary_base=10000,
+        original_max_position_embeddings=0,
+        max_position_embeddings=0,
         parallel_residual=False,
         shared_layer_norm=False,
         num_heads_kv=None,
@@ -267,6 +276,8 @@ def __init__(
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=rotary_base,
+            original_max_position_embeddings=original_max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             num_heads_kv=num_heads_kv,
             head_dim=head_dim,
             sliding_window=sliding_window,
@@ -499,6 +510,8 @@ def from_config(
         rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
         rotary_scaling_factor: float = 1,
         rotary_base: float = 10000,
+        original_max_position_embeddings: int = 0,
+        max_position_embeddings: int = 0,
         parallel_residual: bool = False,
         shared_layer_norm: bool = False,
         multi_query_attention: bool = False,
@@ -531,6 +544,9 @@ def from_config(
             rotary_scaling_type: Type of RoPE scaling.
             rotary_scaling_factor: Factor used in the RoPE scaling.
             rotary_base: The base period of the rotary embeddings.
+            original_max_position_embeddings: The original max position embeddings
+                for Su RoPE scaling.
+            max_position_embeddings: The max position embeddings for Su RoPE scaling.
             parallel_residual: Use parallel residual connections in each layer block,
                 as used by the GPT-J and GPT-NeoX models.
             shared_layer_norm: When using parallel residual, share the input and post
@@ -559,6 +575,8 @@ def from_config(
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=rotary_base,
+            original_max_position_embeddings=original_max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             parallel_residual=parallel_residual,
             shared_layer_norm=shared_layer_norm,
             multi_query_attention=multi_query_attention,
diff --git a/src/layers/attention_layer.cc b/src/layers/attention_layer.cc
index 18b4fa16b..0a85d002a 100644
--- a/src/layers/attention_layer.cc
+++ b/src/layers/attention_layer.cc
@@ -80,6 +80,15 @@ namespace ctranslate2 {
         scope + "/rotary_scaling_type", -1);
       const auto scaling_factor = model.get_attribute_with_default(
         scope + "/rotary_scaling_factor", 1.f);
+      const auto rotary_long_factor = model.get_variable_if_exists(scope +
+                                                                   "/rotary_scaling_long_factor");
+      const auto rotary_short_factor = model.get_variable_if_exists(scope +
+                                                                    "/rotary_scaling_short_factor");
+      const auto original_max_position_embeddings = model.get_attribute_with_default(
+        scope + "/original_max_position_embeddings", 0);
+
+      const auto max_position_embeddings = model.get_attribute_with_default(
+        scope + "/max_position_embeddings", 0);
 
       return std::make_unique<RotaryEmbeddings>(rotary_dim,
                                                 interleave,
@@ -87,6 +96,10 @@ namespace ctranslate2 {
                                                 scaling_factor,
                                                 base,
                                                 /*num_initial_positions*/2048,
+                                                rotary_long_factor,
+                                                rotary_short_factor,
+                                                original_max_position_embeddings,
+                                                max_position_embeddings,
                                                 transpose);
     }
 
@@ -162,6 +175,10 @@ namespace ctranslate2 {
                                        const float scaling_factor,
                                        const float base,
                                        const dim_t num_initial_positions,
+                                       const StorageView* long_scaling_factor,
+                                       const StorageView* short_scaling_factor,
+                                       const dim_t original_max_position_embeddings,
+                                       const dim_t max_position_embeddings,
                                        const bool transpose)
       : _dim(dim)
       , _interleave(interleave)
@@ -169,9 +186,19 @@ namespace ctranslate2 {
       , _scaling_factor(scaling_factor)
       , _base(base)
       , _num_initial_positions(num_initial_positions)
+      , _rotary_scaling_long_factor(long_scaling_factor ?
+                                    std::make_unique<StorageView>(*long_scaling_factor) : nullptr)
+      , _rotary_scaling_short_factor(short_scaling_factor ?
+                                     std::make_unique<StorageView>(*short_scaling_factor) : nullptr)
+      , _original_max_position_embeddings(original_max_position_embeddings)
+      , _max_position_embeddings(max_position_embeddings)
       , _rotary_op(dim, interleave)
       , _transpose(transpose)
     {
+      if (_rotary_scaling_long_factor && _rotary_scaling_long_factor->device() != Device::CPU)
+        _rotary_scaling_long_factor = std::make_unique<StorageView>(_rotary_scaling_long_factor->to(Device::CPU));
+      if (_rotary_scaling_short_factor && _rotary_scaling_short_factor->device() != Device::CPU)
+        _rotary_scaling_short_factor = std::make_unique<StorageView>(_rotary_scaling_short_factor->to(Device::CPU));
     }
 
     void RotaryEmbeddings::apply(StorageView& x, const dim_t offset, bool apply) {
@@ -206,14 +233,27 @@ namespace ctranslate2 {
                                   const Device device,
                                   const DataType dtype) {
       StorageView inv_freq({1, dim / 2});
-      for (dim_t i = 0; i < inv_freq.size(); ++i)
-        inv_freq.at<float>(i) = 1.f / std::pow(_base, float(i * 2) / float(dim));
+      if (_scaling_type == RotaryScalingType::Su) {
+        StorageView* scaling_factor;
+        for (dim_t i = 0; i < inv_freq.size(); ++i) {
+          if (num_positions > _original_max_position_embeddings)
+            scaling_factor = _rotary_scaling_long_factor.get();
+          else
+            scaling_factor = _rotary_scaling_short_factor.get();
+          inv_freq.at<float>(i) = 1.f / (scaling_factor->at<float>(i) *
+                                         (std::pow(_base, float(i * 2) / float(dim))));
+        }
+      }
+      else {
+        for (dim_t i = 0; i < inv_freq.size(); ++i)
+          inv_freq.at<float>(i) = 1.f / std::pow(_base, float(i * 2) / float(dim));
+      }
       if (inv_freq.device() != device)
         inv_freq = inv_freq.to(device);
 
       StorageView t({num_positions, 1});
       for (dim_t i = 0; i < t.size(); ++i)
-        t.at<float>(i) = _scaling_type == RotaryScalingType::None ? i : float(i) / _scaling_factor;
+        t.at<float>(i) = _scaling_type != RotaryScalingType::Linear ? i : float(i) / _scaling_factor;
       if (t.device() != device)
         t = t.to(device);
 
@@ -226,8 +266,9 @@ namespace ctranslate2 {
       StorageView emb(device);
       ops::Concat(-1)({&freqs, &freqs}, emb);
 
-      if (_interleave)
+      if (_interleave) {
         emb.reshape({num_positions, dim});
+      }
 
       StorageView sin(device);
       ops::Sin()(emb, sin);
@@ -242,6 +283,18 @@ namespace ctranslate2 {
         _cos = std::move(cos);
       else
         _cos = cos.to(dtype);
+
+      if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0) {
+        StorageView scaling_factor;
+        float scale = static_cast<float>(_max_position_embeddings) / _original_max_position_embeddings;
+        if (scale <= 1)
+          scaling_factor = StorageView(1.0f, device);
+        else
+          scaling_factor = StorageView(static_cast<float>(sqrt(1 + std::log(scale) / std::log(_original_max_position_embeddings))), device);
+
+        ops::Mul()(_sin, scaling_factor, _sin);
+        ops::Mul()(_cos, scaling_factor, _cos);
+      }
     }
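
Note on the Su scaling path added above: when the requested number of positions exceeds original_max_position_embeddings, the per-dimension long factors divide the RoPE inverse frequencies, otherwise the short factors do, and once the context is extended the resulting sin/cos tables are multiplied by sqrt(1 + ln(max/original) / ln(original)). The NumPy sketch below mirrors the logic of the RotaryEmbeddings update in attention_layer.cc for reference only; the function name and signature are illustrative and not part of the patch.

import numpy as np


def su_rope_tables(dim, num_positions, base, long_factor, short_factor,
                   original_max_position_embeddings, max_position_embeddings):
    # Pick the per-dimension factors (length dim / 2): long factors once the
    # requested positions exceed the original training context, short otherwise.
    factor = np.asarray(
        long_factor
        if num_positions > original_max_position_embeddings
        else short_factor,
        dtype=np.float32,
    )
    # Scaled inverse frequencies: 1 / (factor_i * base^(2i / dim)).
    i = np.arange(dim // 2, dtype=np.float32)
    inv_freq = 1.0 / (factor * base ** (2.0 * i / dim))
    positions = np.arange(num_positions, dtype=np.float32)
    freqs = np.outer(positions, inv_freq)
    emb = np.concatenate([freqs, freqs], axis=-1)
    sin, cos = np.sin(emb), np.cos(emb)
    # Global attention scaling applied to both tables when the context is extended.
    scale = max_position_embeddings / original_max_position_embeddings
    if scale > 1:
        attn_factor = np.sqrt(1 + np.log(scale) / np.log(original_max_position_embeddings))
        sin, cos = sin * attn_factor, cos * attn_factor
    return sin, cos

For instance, Phi-3-mini-128k ships long/short factor lists of length head_dim / 2 (the converter reads them from layer.self_attn.rotary_emb) with max_position_embeddings = 131072 and original_max_position_embeddings = 4096, so beyond 4096 positions the long factors are selected and both tables are scaled by sqrt(1 + ln(32) / ln(4096)) ≈ 1.19.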