From 3f90226c92e5923640a2b7357561d99f022f361e Mon Sep 17 00:00:00 2001
From: hkwon
Date: Thu, 12 Sep 2024 21:11:21 -0700
Subject: [PATCH] patch from https://github.com/OpenNMT/CTranslate2/issues/1711

---
 python/ctranslate2/converters/transformers.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 83033b203..d90ff3569 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -357,7 +357,17 @@ def set_attention(self, spec, attention, self_attention=False):
         self.set_linear(spec.linear[-1], attention.out_proj)
 
     def set_common_layers(self, spec, module):
-        spec.scale_embeddings = module.embed_scale
+        import math
+
+        if not hasattr(module, "embed_scale"):
+            embed_scale = (
+                math.sqrt(module.config.d_model)
+                if module.config.scale_embedding
+                else 1.0
+            )
+        else:
+            embed_scale = module.embed_scale
+        spec.scale_embeddings = embed_scale
         self.set_position_encodings(spec.position_encodings, module.embed_positions)
         self.set_embeddings(
             (