support phi-3 128k (#1700)
* support su rotary embedding
* fix black
* fix test rope embeddings
* fix flake
* fix tests
* small fix
* fix phi3 8k
minhthuc2502 authored May 17, 2024 · 1 parent 580c685 · commit a1eaa71
Showing 6 changed files with 150 additions and 16 deletions.
9 changes: 9 additions & 0 deletions include/ctranslate2/layers/attention_layer.h
@@ -72,6 +72,7 @@ namespace ctranslate2 {
enum class RotaryScalingType {
None = -1,
Linear,
Su,
};

class RotaryEmbeddings {
@@ -82,6 +83,10 @@ namespace ctranslate2 {
const float scaling_factor = 1,
const float base = 10000,
const dim_t num_initial_positions = 2048,
const StorageView* long_scaling_factor = nullptr,
const StorageView* short_scaling_factor = nullptr,
const dim_t original_max_position_embeddings = 0,
const dim_t max_position_embeddings = 0,
const bool transpose = true);

void apply(StorageView& x, const dim_t offset = 0, bool apply = true);
@@ -110,6 +115,10 @@ namespace ctranslate2 {
const float _scaling_factor;
const float _base;
const dim_t _num_initial_positions;
std::unique_ptr<StorageView> _rotary_scaling_long_factor;
std::unique_ptr<StorageView> _rotary_scaling_short_factor;
const dim_t _original_max_position_embeddings;
const dim_t _max_position_embeddings;
const ops::Rotary _rotary_op;
const bool _transpose;

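The header above only declares the new Su state: the long and short per-dimension factor tensors plus the original and extended position limits. The C++ code that applies them lives in a source file not rendered on this page, so as orientation only, here is a minimal Python sketch of Su (LongRoPE-style) scaling as used by Phi-3 128k, following the Hugging Face reference formulation; the function name and the exact placement of the logic are assumptions, not the CTranslate2 implementation.

import math
import numpy as np

def su_scaled_rotary_params(dim, base, seq_len, long_factor, short_factor,
                            original_max_position_embeddings,
                            max_position_embeddings):
    # Pick the per-dimension factors: "long" once the requested length exceeds
    # the original training context, otherwise "short".
    factors = np.asarray(
        long_factor if seq_len > original_max_position_embeddings else short_factor,
        dtype=np.float32,
    )
    # Each rotary inverse frequency is divided by its factor.
    inv_freq = 1.0 / (factors * base ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    # The sin/cos tables are rescaled when the configured context is extended.
    scale = max_position_embeddings / original_max_position_embeddings
    mscale = 1.0 if scale <= 1.0 else math.sqrt(
        1.0 + math.log(scale) / math.log(original_max_position_embeddings)
    )
    return inv_freq, mscale

In the Hugging Face code the long/short choice depends on the runtime sequence length; whether the C++ port makes that choice per call or once at load time is not visible from this header alone.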
52 changes: 42 additions & 10 deletions python/ctranslate2/converters/transformers.py
@@ -40,6 +40,7 @@

_SUPPORTED_ROPE_SCALING = {
"linear": attention_spec.RotaryScalingType.Linear,
"su": attention_spec.RotaryScalingType.Su,
}

_MODEL_LOADERS = {}
@@ -346,9 +347,11 @@ def set_common_layers(self, spec, module):
spec.scale_embeddings = module.embed_scale
self.set_position_encodings(spec.position_encodings, module.embed_positions)
self.set_embeddings(
spec.embeddings[0]
if isinstance(spec.embeddings, list)
else spec.embeddings,
(
spec.embeddings[0]
if isinstance(spec.embeddings, list)
else spec.embeddings
),
module.embed_tokens,
)

@@ -1066,9 +1069,11 @@ def set_config(self, config, model, tokenizer):
def set_stack(self, spec, module, is_decoder=False):
self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
self.set_embeddings(
spec.embeddings[0]
if isinstance(spec.embeddings, list)
else spec.embeddings,
(
spec.embeddings[0]
if isinstance(spec.embeddings, list)
else spec.embeddings
),
module.embed_tokens,
)

@@ -1298,9 +1303,11 @@ def get_model_spec(self, model):
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=common_spec.Activation.GELU
if activation_config == "gelu"
else common_spec.Activation.GELUTanh,
activation=(
common_spec.Activation.GELU
if activation_config == "gelu"
else common_spec.Activation.GELUTanh
),
pre_norm=True,
ffn_glu=True,
rms_norm=True,
@@ -1694,10 +1701,14 @@ def get_model_spec(self, model):
if num_heads_kv == num_heads:
num_heads_kv = None

original_max_position_embeddings = getattr(
model.config, "original_max_position_embeddings", 0
)
max_position_embeddings = getattr(model.config, "max_position_embeddings", 0)
rope_scaling = getattr(model.config, "rope_scaling", None)
if rope_scaling:
rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
rotary_scaling_factor = rope_scaling["factor"]
rotary_scaling_factor = rope_scaling.get("factor", 1)

if rotary_scaling_type is None:
raise NotImplementedError(
@@ -1721,6 +1732,8 @@
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=getattr(model.config, "rope_theta", 10000),
original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=max_position_embeddings,
num_heads_kv=num_heads_kv,
)
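For reference, the snippet below shows roughly what this loader consumes from a Phi-3 128k style Hugging Face config. The mapping is the one added at the top of transformers.py; the factor lists are illustrative placeholders, not values from a real checkpoint.

from ctranslate2.specs import attention_spec

_SUPPORTED_ROPE_SCALING = {
    "linear": attention_spec.RotaryScalingType.Linear,
    "su": attention_spec.RotaryScalingType.Su,
}

# Hypothetical rope_scaling block from a config.json (values are illustrative):
rope_scaling = {
    "type": "su",
    "long_factor": [1.08, 1.17, 1.36],   # truncated per-dimension factor list
    "short_factor": [1.0, 1.0, 1.0],
}

rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
# "su" entries carry no "factor" key, hence the new .get("factor", 1) fallback:
rotary_scaling_factor = rope_scaling.get("factor", 1)
assert rotary_scaling_type is attention_spec.RotaryScalingType.Su
assert rotary_scaling_factor == 1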

@@ -1748,6 +1761,16 @@ def set_config(self, config, model, tokenizer):
def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight

def set_rotary_embeddings(
self, spec, rotary_scaling_long_factor, rotary_scaling_short_factor
):
spec.rotary_scaling_long_factor = torch.tensor(
rotary_scaling_long_factor, dtype=torch.float32
)
spec.rotary_scaling_short_factor = torch.tensor(
rotary_scaling_short_factor, dtype=torch.float32
)

def set_decoder(self, spec, module):
spec.scale_embeddings = False
self.set_embeddings(spec.embeddings, module.embed_tokens)
@@ -1765,6 +1788,15 @@ def set_decoder(self, spec, module):
layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
)
self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)
if (
layer.self_attn.rotary_emb.long_factor is not None
and layer.self_attn.rotary_emb.short_factor is not None
):
self.set_rotary_embeddings(
layer_spec.self_attention,
layer.self_attn.rotary_emb.long_factor,
layer.self_attn.rotary_emb.short_factor,
)

gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
layer_spec.ffn.linear_0.weight = gate_proj
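The set_decoder loop above copies each layer's long/short factor lists from the checkpoint's rotary module into float32 tensors on the attention spec, and the is-not-None guard skips checkpoints that do not define them (compare the "fix phi3 8k" line in the commit message). A self-contained sketch of that flow, using a stand-in object instead of a real transformers layer:

import types
import torch

# Stand-in for layer.self_attn.rotary_emb; the attribute names mirror the loop
# above, the numeric values are purely illustrative.
rotary_emb = types.SimpleNamespace(
    long_factor=[1.08, 1.17, 1.36],
    short_factor=[1.0, 1.0, 1.0],
)
self_attention_spec = types.SimpleNamespace()

if rotary_emb.long_factor is not None and rotary_emb.short_factor is not None:
    # Same conversion as set_rotary_embeddings(): store float32 tensors on the spec.
    self_attention_spec.rotary_scaling_long_factor = torch.tensor(
        rotary_emb.long_factor, dtype=torch.float32
    )
    self_attention_spec.rotary_scaling_short_factor = torch.tensor(
        rotary_emb.short_factor, dtype=torch.float32
    )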
16 changes: 16 additions & 0 deletions python/ctranslate2/specs/attention_spec.py
@@ -10,6 +10,7 @@ class RotaryScalingType(enum.IntEnum):
"""RoPE scaling type."""

Linear = 0
Su = 1


class MultiHeadAttentionSpec(model_spec.LayerSpec):
@@ -24,6 +25,8 @@ def __init__(
rotary_scaling_type=None,
rotary_scaling_factor=1,
rotary_base=10000,
original_max_position_embeddings=0,
max_position_embeddings=0,
num_heads_kv=None,
head_dim=None,
sliding_window=None,
@@ -43,16 +46,29 @@ def __init__(
self.relative_attention_bias = None
self.relative_attention_max_distance = None

if original_max_position_embeddings != 0:
self.original_max_position_embeddings = np.dtype("int32").type(
original_max_position_embeddings
)
if max_position_embeddings != 0:
self.max_position_embeddings = np.dtype("int32").type(
max_position_embeddings
)

if rotary_dim is not None:
self.rotary_dim = np.dtype("int32").type(rotary_dim)
self.rotary_interleave = rotary_interleave
self.rotary_base = np.dtype("float32").type(rotary_base)

if rotary_scaling_type is not None:
self.rotary_scaling_type = np.dtype("int8").type(rotary_scaling_type)
if rotary_scaling_type is RotaryScalingType.Linear:
self.rotary_scaling_factor = np.dtype("float32").type(
rotary_scaling_factor
)
elif rotary_scaling_type is RotaryScalingType.Su:
self.rotary_scaling_long_factor = None
self.rotary_scaling_short_factor = None

if num_heads_kv is not None:
self.num_heads_kv = np.dtype("int32").type(num_heads_kv)
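To see what these branches produce, the sketch below builds one spec per scaling type. It assumes the constructor arguments not shown in this excerpt keep their defaults, and that rotary_dim must be set for the scaling attributes to be populated, as the nesting above suggests.

from ctranslate2.specs import attention_spec

# Linear scaling stores a single scalar factor on the spec.
linear = attention_spec.MultiHeadAttentionSpec(
    rotary_dim=0,
    rotary_scaling_type=attention_spec.RotaryScalingType.Linear,
    rotary_scaling_factor=2.0,
)
print(linear.rotary_scaling_factor)   # float32 scalar 2.0

# Su scaling only declares tensor placeholders; the converter fills them later,
# layer by layer, from the checkpoint's long/short factor lists.
su = attention_spec.MultiHeadAttentionSpec(
    rotary_dim=0,
    rotary_scaling_type=attention_spec.RotaryScalingType.Su,
    original_max_position_embeddings=4096,   # illustrative values
    max_position_embeddings=131072,
)
print(su.rotary_scaling_long_factor)   # None until the converter sets it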
10 changes: 8 additions & 2 deletions python/ctranslate2/specs/model_spec.py
@@ -35,6 +35,8 @@
"float32",
)

SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor")


def _join_scope(scope, name):
if not scope:
@@ -175,9 +177,13 @@ def _alias_variables(self):
break
# Because variables can be transformed on load (e.g. transposed),
# we use an element-wise equality check.
if not value.is_scalar() and value.equal(other_value):
scope, attr_name = _parent_scope(name)
if (
not value.is_scalar()
and value.equal(other_value)
and attr_name not in SKIP_CREATING_ALIAS
):
# Replace variable value by the alias name.
scope, attr_name = _parent_scope(name)
spec = index_spec(self, scope)
setattr(spec, attr_name, other_name)
break
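The alias pass above deduplicates identical non-scalar variables by pointing later occurrences at the first one. The new SKIP_CREATING_ALIAS tuple exempts the Su scaling factors, which are typically identical across layers, so every layer keeps its own tensor. A hypothetical predicate restating that guard:

SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor")

def can_create_alias(attr_name, value, other_value):
    # Mirrors the condition above: duplicate non-scalar variables are normally
    # replaced by an alias to their first occurrence, except for the per-layer
    # Su rotary scaling factors.
    return (
        not value.is_scalar()
        and value.equal(other_value)
        and attr_name not in SKIP_CREATING_ALIAS
    )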
18 changes: 18 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -97,6 +97,8 @@ def __init__(
rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
rotary_scaling_factor: float = 1,
rotary_base: float = 10000,
original_max_position_embeddings: int = 0,
max_position_embeddings: int = 0,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
multi_query_attention: bool = False,
@@ -135,6 +137,9 @@ def __init__(
rotary_scaling_type: Type of RoPE scaling.
rotary_scaling_factor: Factor used in the RoPE scaling.
rotary_base: The base period of the rotary embeddings.
original_max_position_embeddings: The original max position embeddings
for Su RoPE scaling.
max_position_embeddings: The max position embeddings for Su RoPE scaling.
parallel_residual: Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
@@ -199,6 +204,8 @@ def __init__(
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=rotary_base,
original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=max_position_embeddings,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
num_heads_kv=num_heads_kv,
@@ -251,6 +258,8 @@ def __init__(
rotary_scaling_type=None,
rotary_scaling_factor=1,
rotary_base=10000,
original_max_position_embeddings=0,
max_position_embeddings=0,
parallel_residual=False,
shared_layer_norm=False,
num_heads_kv=None,
@@ -267,6 +276,8 @@
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=rotary_base,
original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=max_position_embeddings,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
@@ -499,6 +510,8 @@ def from_config(
rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
rotary_scaling_factor: float = 1,
rotary_base: float = 10000,
original_max_position_embeddings: int = 0,
max_position_embeddings: int = 0,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
multi_query_attention: bool = False,
@@ -531,6 +544,9 @@ def from_config(
rotary_scaling_type: Type of RoPE scaling.
rotary_scaling_factor: Factor used in the RoPE scaling.
rotary_base: The base period of the rotary embeddings.
original_max_position_embeddings: The original max position embeddings
for Su RoPE scaling.
max_position_embeddings: The max position embeddings for Su RoPE scaling.
parallel_residual: Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
@@ -559,6 +575,8 @@ def from_config(
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=rotary_base,
original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=max_position_embeddings,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
multi_query_attention=multi_query_attention,
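Putting the pieces together, a Phi-3 128k style decoder spec could be declared roughly as below, mirroring the converter changes in this commit. All numeric values are illustrative, and the activation and rotary_dim arguments are assumptions borrowed from similar decoder loaders rather than from this diff.

from ctranslate2.specs import attention_spec, common_spec, transformer_spec

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
    32,                                        # num_layers
    32,                                        # num_heads
    activation=common_spec.Activation.SWISH,   # assumed SiLU-style activation
    pre_norm=True,
    ffn_glu=True,
    rms_norm=True,
    rotary_dim=0,                              # full-dimension rotation (assumed)
    rotary_scaling_type=attention_spec.RotaryScalingType.Su,
    rotary_scaling_factor=1,
    rotary_base=10000,
    original_max_position_embeddings=4096,
    max_position_embeddings=131072,
    num_heads_kv=None,
)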
(The sixth changed file is not rendered on this page.)
