Skip to content

Commit c88231b

Browse files
committed
whisper checkpoint
1 parent 8606327 commit c88231b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+7416
-25
lines changed

python/mlc_chat/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
MLC Chat is the app runtime of MLC LLM.
44
"""
5-
from . import protocol, serve
5+
6+
# from . import protocol, serve
67
from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig
78
from .libinfo import __version__

python/mlc_chat/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Load MLC LLM library and _ffi_api functions."""
2+
23
import ctypes
34
import os
45
import sys
@@ -24,5 +25,5 @@ def _load_mlc_llm_lib():
2425

2526

2627
# only load once here
27-
if SKIP_LOADING_MLCLLM_SO == "0":
28-
_LIB, _LIB_PATH = _load_mlc_llm_lib()
28+
# if SKIP_LOADING_MLCLLM_SO == "0":
29+
# _LIB, _LIB_PATH = _load_mlc_llm_lib()

python/mlc_chat/compiler_pass/pipeline.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""The compilation pipeline for LLM applications."""
2+
23
from pathlib import Path
34
from typing import Any, Dict, List, Optional
45

python/mlc_chat/interface/compile.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Python entrypoint of compilation."""
2+
23
import dataclasses
34
import math
45
from io import StringIO

python/mlc_chat/model/model.py

+14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .phi import phi_loader, phi_model, phi_quantization
1818
from .qwen import qwen_loader, qwen_model, qwen_quantization
1919
from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization
20+
from .whisper import whisper_loader, whisper_model, whisper_quantization
2021

2122
ModelConfig = Any
2223
"""A ModelConfig is an object that represents a model architecture. It is required to have
@@ -195,4 +196,17 @@ class Model:
195196
"group-quant": stablelm_quantization.group_quant,
196197
},
197198
),
199+
"whisper": Model(
200+
name="whisper",
201+
model=whisper_model.WhisperForConditionalGeneration,
202+
config=whisper_model.WhisperConfig,
203+
source={
204+
"huggingface-torch": whisper_loader.huggingface,
205+
"huggingface-safetensor": whisper_loader.huggingface,
206+
},
207+
quantize={
208+
"no-quant": whisper_quantization.no_quant,
209+
"group-quant": whisper_quantization.group_quant,
210+
},
211+
),
198212
}

python/mlc_chat/model/whisper/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
This file specifies how MLC's Whisper parameters are mapped from other formats, for example HuggingFace
3+
PyTorch, HuggingFace safetensors.
4+
"""
5+
6+
import functools
7+
8+
from mlc_chat.loader import ExternMapping
9+
from mlc_chat.quantization import Quantization
10+
11+
from .whisper_model import WhisperConfig, WhisperForConditionalGeneration
12+
13+
14+
def huggingface(model_config: WhisperConfig, quantization: Quantization) -> ExternMapping:
    """Build the parameter mapping from MLC LLM parameter names to
    HuggingFace PyTorch parameter names for Whisper.

    Each MLC parameter is mapped to the source parameter carrying the same
    name; the only conversion applied is a cast to the MLC parameter's dtype.

    Parameters
    ----------
    model_config : WhisperConfig
        The configuration of the Whisper model.

    quantization : Quantization
        The quantization configuration. When it is not None, the model is
        converted to ``quantization.model_dtype`` before export.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """

    def _cast(x, dtype):
        # Convert a loaded source array to the dtype expected by MLC.
        return x.astype(dtype)

    whisper = WhisperForConditionalGeneration(model_config)
    if quantization is not None:
        whisper.to(quantization.model_dtype)

    # Only the named parameters of the export are needed; the other two
    # items of the returned triple are discarded.
    _unused_head, exported_params, _unused_tail = whisper.export_tvm(  # type: ignore[misc]
        spec=whisper.get_default_spec(),
        allow_extern=True,
    )
    param_table = dict(exported_params)

    param_map = ExternMapping()
    for name, param in param_table.items():
        # Identity name mapping: one source tensor per MLC parameter.
        converter = functools.partial(_cast, dtype=param.dtype)
        param_map.add_mapping(name, [name], converter)
    return param_map

0 commit comments

Comments
 (0)