From 2870fe3ddce49c85ecab4f84fc5e4b01b3a740fe Mon Sep 17 00:00:00 2001
From: Minh-Thuc <46375464+minhthuc2502@users.noreply.github.com>
Date: Mon, 25 Nov 2024 15:37:42 +0100
Subject: [PATCH] Support qwen2 (#1820)

* support qwen2

* fix flake
---
 README.md                                     |   2 +-
 python/ctranslate2/converters/transformers.py | 108 ++++++++++++++++++
 2 files changed, 109 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index bfb64c851..fb91f5eb3 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ The project implements a custom runtime that applies many performance optimizati
 The following model types are currently supported:
 
 * Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper
-* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon
+* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2
 * Encoder-only models: BERT, DistilBERT, XLM-RoBERTa
 
 Compatible models should be first converted into an optimized model format. The library includes converters for multiple frameworks:
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 7655662dd..d5f935f95 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -1956,6 +1956,114 @@ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
             gc.collect()
 
 
+@register_loader("Qwen2Config")
+class Qwen2Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Qwen2ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        rope_scaling = getattr(model.config, "rope_scaling", None)
+        if rope_scaling:
+            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
+            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
+            rotary_scaling_factor = rope_scaling["factor"]
+
+            if rotary_scaling_type is None:
+                raise NotImplementedError(
+                    "RoPE scaling type '%s' is not yet implemented. "
+                    "The following RoPE scaling types are currently supported: %s"
+                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                )
+        else:
+            rotary_scaling_type = None
+            rotary_scaling_factor = 1
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=common_spec.Activation.SWISH,
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=0,
+            rotary_interleave=False,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+        )
+
+        self.set_decoder(spec.decoder, model.model)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = (
+            tokenizer.bos_token
+            if tokenizer.bos_token is not None
+            else tokenizer.pad_token
+        )
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = (
+            tokenizer.unk_token if tokenizer.unk_token is not None else ""
+        )
+        config.layer_norm_epsilon = model.config.rms_norm_eps
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.input_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
+            )
+
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(split_layers[0], layer.self_attn.q_proj)
+            self.set_linear(split_layers[1], layer.self_attn.k_proj)
+            self.set_linear(split_layers[2], layer.self_attn.v_proj)
+
+            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
+            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("MixFormerSequentialConfig")
 class MixFormerSequentialLoader(ModelLoader):
     @property