Add phi3 converter (#1680)
* add phi3 converter

* PhiLoader to Phi3Loader

* fix black
minhthuc2502 authored Apr 26, 2024
1 parent 0527ef7 commit 9d54f5d
Showing 1 changed file with 96 additions and 0 deletions.
python/ctranslate2/converters/transformers.py (96 additions, 0 deletions)
@@ -1680,6 +1680,102 @@ def set_decoder(self, spec, module):
            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)


@register_loader("Phi3Config")
class Phi3Loader(ModelLoader):
    @property
    def architecture_name(self):
        return "AutoModelForCausalLM"

    def get_model_spec(self, model):
        num_layers = model.config.num_hidden_layers

        num_heads = model.config.num_attention_heads
        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
        if num_heads_kv == num_heads:
            num_heads_kv = None

        rope_scaling = getattr(model.config, "rope_scaling", None)
        if rope_scaling:
            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
            rotary_scaling_factor = rope_scaling["factor"]

            if rotary_scaling_type is None:
                raise NotImplementedError(
                    "RoPE scaling type '%s' is not yet implemented. "
                    "The following RoPE scaling types are currently supported: %s"
                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
                )
        else:
            rotary_scaling_type = None
            rotary_scaling_factor = 1

        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
            num_layers,
            num_heads,
            activation=common_spec.Activation.SWISH,
            pre_norm=True,
            ffn_glu=True,
            rms_norm=True,
            rotary_dim=0,
            rotary_interleave=False,
            rotary_scaling_type=rotary_scaling_type,
            rotary_scaling_factor=rotary_scaling_factor,
            rotary_base=getattr(model.config, "rope_theta", 10000),
            num_heads_kv=num_heads_kv,
        )

        self.set_decoder(spec.decoder, model.model)
        self.set_linear(spec.decoder.projection, model.lm_head)
        return spec

    def get_vocabulary(self, model, tokenizer):
        tokens = super().get_vocabulary(model, tokenizer)

        extra_ids = model.config.vocab_size - len(tokens)
        for i in range(extra_ids):
            tokens.append("<extra_id_%d>" % i)

        return tokens

    def set_vocabulary(self, spec, tokens):
        spec.register_vocabulary(tokens)

    def set_config(self, config, model, tokenizer):
        config.bos_token = tokenizer.bos_token
        config.eos_token = tokenizer.eos_token
        config.unk_token = tokenizer.unk_token

    def set_layer_norm(self, spec, layer_norm):
        spec.gamma = layer_norm.weight

    def set_decoder(self, spec, module):
        spec.scale_embeddings = False
        self.set_embeddings(spec.embeddings, module.embed_tokens)
        self.set_layer_norm(spec.layer_norm, module.norm)

        for layer_spec, layer in zip(spec.layer, module.layers):
            self.set_layer_norm(
                layer_spec.self_attention.layer_norm, layer.input_layernorm
            )
            self.set_layer_norm(
                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
            )

            self.set_linear(
                layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
            )
            self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)

            gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
            layer_spec.ffn.linear_0.weight = gate_proj
            layer_spec.ffn.linear_0_noact.weight = up_proj
            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

            delattr(layer, "self_attn")
            delattr(layer, "mlp")
            gc.collect()


@register_loader("RWConfig")
class RWLoader(ModelLoader):
    @property
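A note on the fused MLP weights (not part of the commit): Phi-3 checkpoints store the SwiGLU gate and up projections as a single fused gate_up_proj linear, so the loader splits that weight matrix in half along the output dimension and assigns the halves to linear_0 (the half passed through the SWISH/SiLU activation) and linear_0_noact of CTranslate2's GLU feed-forward spec. A minimal sketch of the split with a dummy tensor and illustrative sizes:

import torch

# Fused weight of shape [2 * intermediate_size, hidden_size], as in Phi-3's gate_up_proj.
hidden_size, intermediate_size = 8, 16
fused_weight = torch.randn(2 * intermediate_size, hidden_size)

# chunk(2, dim=0) yields the gate half followed by the up half, mirroring the loader.
gate_weight, up_weight = fused_weight.chunk(2, dim=0)
assert gate_weight.shape == up_weight.shape == (intermediate_size, hidden_size)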

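For reference, a minimal usage sketch (not part of the commit) of converting and running a Phi-3 model with this loader; the model name, output directory, and quantization choice are illustrative:

import ctranslate2
import transformers
from ctranslate2.converters import TransformersConverter

# Convert the Hugging Face checkpoint to a CTranslate2 model directory.
converter = TransformersConverter("microsoft/Phi-3-mini-4k-instruct")
output_dir = converter.convert("phi3-ct2", quantization="int8")

# Generate a few tokens with the converted model.
tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
generator = ctranslate2.Generator(output_dir)
prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, my name is"))
results = generator.generate_batch([prompt_tokens], max_length=32)
print(tokenizer.decode(results[0].sequences_ids[0]))

The same conversion is also available through the ct2-transformers-converter command line entry point.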