From 2de8b2bcbe9f25793987aa706d2525a4626e161b Mon Sep 17 00:00:00 2001
From: Dmitrii Mukhutdinov
Date: Wed, 15 Nov 2023 11:38:43 +0000
Subject: [PATCH] Better HF-to-CT2 conversion for Whisper model

---
 python/ctranslate2/converters/transformers.py | 33 ++++++++++++++++---
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 185ffa50b..0af24b27b 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -890,11 +890,36 @@ def get_model_spec(self, model):
         return spec
 
     def set_config(self, config, model, tokenizer):
-        config.suppress_ids = model.config.suppress_tokens
-        config.suppress_ids_begin = model.config.begin_suppress_tokens
-        config.lang_ids = tokenizer.additional_special_tokens_ids[2:-6]
+        gen_config = getattr(model, "generation_config", None)
+
+        if gen_config is not None:
+            config.suppress_ids = gen_config.suppress_tokens
+            config.suppress_ids_begin = gen_config.begin_suppress_tokens
+            config.alignment_heads = gen_config.alignment_heads
+        else:
+            config.suppress_ids = model.config.suppress_tokens
+            config.suppress_ids_begin = model.config.begin_suppress_tokens
+            config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
+
+        non_lang_special_tokens = [
+            "<|endoftext|>",
+            "<|startoftranscript|>",
+            "<|translate|>",
+            "<|transcribe|>",
+            "<|startoflm|>",
+            "<|startofprev|>",
+            "<|nocaptions|>",
+            "<|notimestamps|>",
+        ]
+        config.lang_ids = [
+            token_id
+            for token_id, token in zip(
+                tokenizer.additional_special_tokens_ids,
+                tokenizer.additional_special_tokens,
+            )
+            if token not in non_lang_special_tokens
+        ]
 
-        config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
         if config.alignment_heads is None:
             # Use the last half layers for alignment by default.
             num_layers = model.config.decoder_layers