Skip to content

Commit

Permalink
Better HF-to-CT2 conversion for Whisper model
Browse files — browse the repository at this point in the history
  • Loading branch information
flyingleafe committed Nov 15, 2023
1 parent adc8262 commit 2de8b2b
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions python/ctranslate2/converters/transformers.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -890,11 +890,36 @@ def get_model_spec(self, model):
return spec

def set_config(self, config, model, tokenizer):
config.suppress_ids = model.config.suppress_tokens
config.suppress_ids_begin = model.config.begin_suppress_tokens
config.lang_ids = tokenizer.additional_special_tokens_ids[2:-6]
gen_config = getattr(model, "generation_config", None)

if gen_config is not None:
config.suppress_ids = gen_config.suppress_tokens
config.suppress_ids_begin = gen_config.begin_suppress_tokens
config.alignment_heads = gen_config.alignment_heads
else:
config.suppress_ids = model.config.suppress_tokens
config.suppress_ids_begin = model.config.begin_suppress_tokens
config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)

non_lang_special_tokens = [
"<|endoftext|>",
"<|startoftranscript|>",
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nocaptions|>",
"<|notimestamps|>",
]
config.lang_ids = [
token_id
for token_id, token in zip(
tokenizer.additional_special_tokens_ids,
tokenizer.additional_special_tokens,
)
if token not in non_lang_special_tokens
]

config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
if config.alignment_heads is None:
# Use the last half layers for alignment by default.
num_layers = model.config.decoder_layers
Expand Down

0 comments on commit 2de8b2b

Please sign in to comment.