diff --git a/whisper/convert.py b/whisper/convert.py index 9cc8b861b..0d7c82762 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -59,6 +59,19 @@ } +def _get_model_variant(name_or_path: str) -> str | None: + """Extract model variant for alignment heads lookup.""" + if name_or_path in _ALIGNMENT_HEADS: + return name_or_path + + # Extract from repo name like "openai/whisper-large-v3" + name = name_or_path.split("/")[-1] + if name.startswith("whisper-"): + return name[8:] # Remove "whisper-" prefix + + return None + + def _download(url: str, root: str) -> str: os.makedirs(root, exist_ok=True) @@ -156,10 +169,11 @@ def load_torch_weights_and_config( if download_root is None: download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper") - # todo: accept alignment_heads of local Pytorch checkpoint - alignment_heads = None + # Look up alignment heads using normalized variant name + variant = _get_model_variant(name_or_path) + alignment_heads = _ALIGNMENT_HEADS.get(variant) if variant else None + if name_or_path in _MODELS: - alignment_heads = _ALIGNMENT_HEADS[name_or_path] name_or_path = _download(_MODELS[name_or_path], download_root) elif not Path(name_or_path).exists(): # Try downloading from HF