
Commit

sync with the main repository
hkwon committed Sep 11, 2024
1 parent f79a752 commit 8e0ce67
Showing 3 changed files with 105 additions and 10 deletions.
9 changes: 0 additions & 9 deletions include/ctranslate2/layers/attention.h
@@ -12,15 +12,6 @@ namespace ctranslate2 {
                                                    dim_t left_max_position,
                                                    dim_t right_max_position);
 
-    /*StorageView make_relative_positions(dim_t queries_length,
-                                          dim_t keys_length,
-                                          dim_t max_position);
-    StorageView make_relative_asymmetric_positions(dim_t queries_length,
-                                                   dim_t keys_length,
-                                                   dim_t left_max_position,
-                                                   dim_t right_max_position);*/
 
     class RotaryEmbeddings;
     class Alibi;

104 changes: 104 additions & 0 deletions python/ctranslate2/converters/transformers.py
@@ -1521,6 +1521,110 @@ def set_decoder(self, spec, module):
             gc.collect()
 
 
+@register_loader("Gemma2Config")
+class Gemma2Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Gemma2ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        activation_config = getattr(
+            model.config, "hidden_activation", "gelu_pytorch_tanh"
+        )
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=0,
+            rotary_interleave=False,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+            head_dim=model.config.head_dim,
+            pre_post_layer_norm=True,
+        )
+
+        self.set_decoder(spec.decoder, model.model)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        if model.config.vocab_size < len(tokens):
+            tokens = tokens[: model.config.vocab_size]
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+        config.layer_norm_epsilon = model.config.rms_norm_eps
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+        spec.layer_norm_use_residual = True
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
+
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            wq = layer.self_attn.q_proj.weight
+            wk = layer.self_attn.k_proj.weight
+            wv = layer.self_attn.v_proj.weight
+            wo = layer.self_attn.o_proj.weight
+
+            layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
+            layer_spec.self_attention.linear[1].weight = wo
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
+            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("LlamaConfig")
 class LlamaLoader(ModelLoader):
     @property
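The new Gemma2Loader is registered for Gemma2Config, so a Gemma 2 checkpoint from Hugging Face should be handled by the standard Transformers converter entry points. Below is a minimal sketch of a conversion, assuming the usual ctranslate2 Python API; the model ID google/gemma-2-2b-it and the output directory gemma2-ct2 are example values, not taken from this commit:

    # Sketch: convert a Gemma 2 checkpoint to the CTranslate2 format.
    # The model ID and output directory are placeholders for illustration.
    from ctranslate2.converters import TransformersConverter

    converter = TransformersConverter("google/gemma-2-2b-it")
    converter.convert("gemma2-ct2", quantization="int8")

The command-line equivalent would be something like: ct2-transformers-converter --model google/gemma-2-2b-it --output_dir gemma2-ct2 --quantization int8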
2 changes: 1 addition & 1 deletion src/layers/attention.cc
@@ -394,7 +394,7 @@ namespace ctranslate2 {
       if (queries_padder)
         queries_padder->add_padding(fused_proj);
 
-      const ops::Split split_op(2, {_d_model, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
+      const ops::Split split_op(2, {_num_heads * _d_head, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
       split_op(fused_proj, queries_proj, keys_proj, values_proj);
 
       if (_merge_time_and_head_dims) {
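This one-line fix matters for architectures like Gemma 2, where head_dim is set explicitly in the config and num_heads * head_dim no longer equals d_model, so the query slice of the fused QKV projection must be sized from the head dimension rather than the model width. This also matches the torch.cat([wq, wk, wv]) weight layout produced by the converter above. A small illustrative sketch of the corrected split sizes, using assumed Gemma-2-2B-like numbers rather than values taken from this commit:

    # Sketch of the fused-QKV split sizes before and after the fix.
    # The config values below are assumed, for illustration only.
    num_heads = 8        # attention heads
    num_heads_kv = 4     # key/value heads (grouped-query attention)
    head_dim = 256       # explicit per-head dimension
    d_model = 2304       # hidden size; note num_heads * head_dim != d_model

    # Old split: query slice sized by the model width (wrong when head_dim is explicit).
    old_sizes = [d_model, num_heads_kv * head_dim, num_heads_kv * head_dim]
    # New split: query slice sized by num_heads * head_dim, matching cat([wq, wk, wv]).
    new_sizes = [num_heads * head_dim, num_heads_kv * head_dim, num_heads_kv * head_dim]

    # The corrected sizes cover the fused projection exactly.
    assert sum(new_sizes) == (num_heads + 2 * num_heads_kv) * head_dim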
