diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 83033b203..d90ff3569 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -357,7 +357,17 @@ def set_attention(self, spec, attention, self_attention=False):
         self.set_linear(spec.linear[-1], attention.out_proj)
 
     def set_common_layers(self, spec, module):
-        spec.scale_embeddings = module.embed_scale
+        import math
+
+        if not hasattr(module, "embed_scale"):
+            embed_scale = (
+                math.sqrt(module.config.d_model)
+                if module.config.scale_embedding
+                else 1.0
+            )
+        else:
+            embed_scale = module.embed_scale
+        spec.scale_embeddings = embed_scale
         self.set_position_encodings(spec.position_encodings, module.embed_positions)
         self.set_embeddings(
             (