From b1d21dbf847681c8d470cda56104a734a4b59784 Mon Sep 17 00:00:00 2001
From: Dmitrii Mukhutdinov
Date: Fri, 17 Nov 2023 07:47:41 +0000
Subject: [PATCH] Use lang_ids from gen config if available

---
 python/ctranslate2/converters/transformers.py | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 0af24b27b..f1e361bb7 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -889,18 +889,7 @@ def get_model_spec(self, model):
 
         return spec
 
-    def set_config(self, config, model, tokenizer):
-        gen_config = getattr(model, "generation_config", None)
-
-        if gen_config is not None:
-            config.suppress_ids = gen_config.suppress_tokens
-            config.suppress_ids_begin = gen_config.begin_suppress_tokens
-            config.alignment_heads = gen_config.alignment_heads
-        else:
-            config.suppress_ids = model.config.suppress_tokens
-            config.suppress_ids_begin = model.config.begin_suppress_tokens
-            config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
-
+    def _get_lang_ids_from_tokenizer(self, tokenizer):
         non_lang_special_tokens = [
             "<|endoftext|>",
             "<|startoftranscript|>",
@@ -911,7 +900,7 @@ def set_config(self, config, model, tokenizer):
             "<|nocaptions|>",
             "<|notimestamps|>",
         ]
-        config.lang_ids = [
+        return [
             token_id
             for token_id, token in zip(
                 tokenizer.additional_special_tokens_ids,
@@ -920,6 +909,20 @@ def set_config(self, config, model, tokenizer):
             if token not in non_lang_special_tokens
         ]
 
+    def set_config(self, config, model, tokenizer):
+        gen_config = getattr(model, "generation_config", None)
+
+        if gen_config is not None:
+            config.suppress_ids = gen_config.suppress_tokens
+            config.suppress_ids_begin = gen_config.begin_suppress_tokens
+            config.alignment_heads = gen_config.alignment_heads
+            config.lang_ids = sorted(gen_config.lang_to_id.values())
+        else:
+            config.suppress_ids = model.config.suppress_tokens
+            config.suppress_ids_begin = model.config.begin_suppress_tokens
+            config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
+            config.lang_ids = self._get_lang_ids_from_tokenizer(tokenizer)
+
         if config.alignment_heads is None:
             # Use the last half layers for alignment by default.
             num_layers = model.config.decoder_layers