From 28812c1910aaa78a3c3cbdec71e9c0183a26f85d Mon Sep 17 00:00:00 2001
From: Minh-Thuc <46375464+minhthuc2502@users.noreply.github.com>
Date: Wed, 10 Apr 2024 12:18:35 +0200
Subject: [PATCH] fix gemma bug (#1660)

* fix gemma bug

* fix black
---
 python/ctranslate2/converters/transformers.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index cc5176806..29c56f2b7 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -1291,10 +1291,16 @@ def get_model_spec(self, model):
         if num_heads_kv == num_heads:
             num_heads_kv = None
 
+        activation_config = getattr(
+            model.config, "hidden_activation", "gelu_pytorch_tanh"
+        )
+
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
-            activation=common_spec.Activation.GELU,
+            activation=common_spec.Activation.GELU
+            if activation_config == "gelu"
+            else common_spec.Activation.GELUTanh,
             pre_norm=True,
             ffn_glu=True,
             rms_norm=True,