diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 1b2d0c98f..6924ac637 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -1530,6 +1530,58 @@ def set_decoder(self, spec, module):
         self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)
 
 
+@register_loader("PhiConfig")
+class PhiLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "AutoModelForCausalLM"
+
+    def get_model_spec(self, model):
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers=model.config.n_layer,
+            num_heads=model.config.n_head,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
+            rotary_dim=model.config.rotary_dim,
+            rotary_interleave=False,
+            parallel_residual=True,
+            shared_layer_norm=True,
+        )
+
+        self.set_decoder(spec.decoder, model.transformer)
+        self.set_linear(spec.decoder.projection, model.lm_head.linear)
+        self.set_layer_norm(spec.decoder.layer_norm, model.lm_head.ln)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module.embd.wte)
+
+        for layer_spec, layer in zip(spec.layer, module.h):
+            self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln)
+            self.set_linear(layer_spec.self_attention.linear[0], layer.mixer.Wqkv)
+            self.set_linear(layer_spec.self_attention.linear[1], layer.mixer.out_proj)
+            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc1)
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)
+
+
 @register_loader("RWConfig")
 class RWLoader(ModelLoader):
     @property
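
For reviewers who want to exercise the new loader end to end, here is a minimal smoke-test sketch (not part of the patch). It assumes the existing ct2-transformers-converter entry point from this repo and uses microsoft/phi-1_5 purely as a placeholder checkpoint; any Hugging Face model whose config class is PhiConfig should take the same path:

    # Convert the checkpoint via the new PhiConfig loader.
    # (placeholder model name; substitute any PhiConfig model)
    ct2-transformers-converter --model microsoft/phi-1_5 --output_dir phi-1_5-ct2

Then a short greedy generation against the converted model:

    import ctranslate2
    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/phi-1_5")
    generator = ctranslate2.Generator("phi-1_5-ct2")

    # CTranslate2 generators consume string tokens, not token ids.
    prompt = tokenizer.convert_ids_to_tokens(tokenizer.encode("def fibonacci(n):"))

    results = generator.generate_batch([prompt], max_length=64, sampling_topk=1)
    print(tokenizer.decode(results[0].sequences_ids[0]))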