Skip to content

Commit

Permalink
Use lang_ids from gen config if available
Browse files Browse the repository at this point in the history
  • Loading branch information
flyingleafe committed Nov 17, 2023
1 parent 6b02733 commit b1d21db
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,18 +889,7 @@ def get_model_spec(self, model):

return spec

def set_config(self, config, model, tokenizer):
gen_config = getattr(model, "generation_config", None)

if gen_config is not None:
config.suppress_ids = gen_config.suppress_tokens
config.suppress_ids_begin = gen_config.begin_suppress_tokens
config.alignment_heads = gen_config.alignment_heads
else:
config.suppress_ids = model.config.suppress_tokens
config.suppress_ids_begin = model.config.begin_suppress_tokens
config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)

def _get_lang_ids_from_tokenizer(self, tokenizer):
non_lang_special_tokens = [
"<|endoftext|>",
"<|startoftranscript|>",
Expand All @@ -911,7 +900,7 @@ def set_config(self, config, model, tokenizer):
"<|nocaptions|>",
"<|notimestamps|>",
]
config.lang_ids = [
return [
token_id
for token_id, token in zip(
tokenizer.additional_special_tokens_ids,
Expand All @@ -920,6 +909,20 @@ def set_config(self, config, model, tokenizer):
if token not in non_lang_special_tokens
]

def set_config(self, config, model, tokenizer):
gen_config = getattr(model, "generation_config", None)

if gen_config is not None:
config.suppress_ids = gen_config.suppress_tokens
config.suppress_ids_begin = gen_config.begin_suppress_tokens
config.alignment_heads = gen_config.alignment_heads
config.lang_ids = sorted(gen_config.lang_to_id.values())
else:
config.suppress_ids = model.config.suppress_tokens
config.suppress_ids_begin = model.config.begin_suppress_tokens
config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
config.lang_ids = self._get_lang_ids_from_tokenizer(tokenizer)

if config.alignment_heads is None:
# Use the last half layers for alignment by default.
num_layers = model.config.decoder_layers
Expand Down

0 comments on commit b1d21db

Please sign in to comment.