
Commit 8bd6918

nverke and MasterJH5574 authored
[Model] Refactor chatglm config to match others
Add a test covering some of this functionality.

---------

Co-authored-by: Ruihang Lai <[email protected]>
1 parent 9d44ae7 commit 8bd6918

3 files changed, +137 -11 lines changed


mlc_llm/core.py

Lines changed: 3 additions & 3 deletions
@@ -364,12 +364,12 @@ def mod_transform_before_build(
     max_seq_len = None
     if args.max_seq_len > 0:
         max_seq_len = args.max_seq_len
-    elif "max_sequence_length" in config:
-        max_seq_len = config["max_sequence_length"]
+    elif hasattr(config, "max_sequence_length"):
+        max_seq_len = config.max_sequence_length

     if max_seq_len:
         mod = fuse_split_rotary_embedding(
-            mod, config["num_attention_heads"], config["hidden_size"], max_seq_len
+            mod, config.num_attention_heads, config.hidden_size, max_seq_len
         )

     if args.target_kind == "cuda":
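
The hunk above is the reason the config type matters downstream: mod_transform_before_build used to index the ChatGLM config like a dict, while every other relax_model config exposes plain attributes. A minimal sketch of the difference (the values are made up, and types.SimpleNamespace stands in for the real config classes):

from types import SimpleNamespace

# Old style: the ChatGLM config behaved like a plain dict, so core.py used
# membership tests and item access (values here are illustrative only).
config_dict = {"max_sequence_length": 2048, "num_attention_heads": 32, "hidden_size": 4096}
if "max_sequence_length" in config_dict:
    max_seq_len = config_dict["max_sequence_length"]

# New style: the config is an object with attributes, like the other
# relax_model configs, so core.py can use hasattr and attribute access.
config_obj = SimpleNamespace(max_sequence_length=2048, num_attention_heads=32, hidden_size=4096)
if hasattr(config_obj, "max_sequence_length"):
    max_seq_len = config_obj.max_sequence_length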

mlc_llm/relax_model/chatglm.py

Lines changed: 8 additions & 8 deletions
@@ -47,7 +47,7 @@ def __init__(
         multi_query_group_num: int = 2,
         num_attention_heads: int = 32,
         num_layers: int = 28,
-        seq_length: int = 2048,
+        max_sequence_length: int = 2048,
         padded_vocab_size: int = 65024,
         eos_token_id: int = 2,
         bos_token_id: int = 0,
@@ -63,7 +63,7 @@ def __init__(
         self.multi_query_group_num = multi_query_group_num
         self.num_attention_heads = num_attention_heads
         self.num_layers = num_layers
-        self.seq_length = min(2048, seq_length)
+        self.max_sequence_length = min(2048, max_sequence_length)
         self.padded_vocab_size = padded_vocab_size
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
@@ -481,14 +481,14 @@ def __init__(self, config: ChatGLMConfig):
             dtype=config.dtype,
         )

-        self.seq_length = config.seq_length
+        self.seq_length = config.max_sequence_length
         rotary_dim = config.kv_channels // 2

         self.rotary_pos_emb = RotaryEmbedding(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
             position_embedding_base=10000,
-            max_sequence_length=config.seq_length,
+            max_sequence_length=config.max_sequence_length,
             rotary_dim=rotary_dim,
             swizzle_style="glm",
             dtype=config.dtype,
@@ -726,7 +726,7 @@ def create_decoding_func(
 def create_kv_cache_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None:
     init_shape = relax.ShapeExpr(
         (
-            config.seq_length,
+            config.max_sequence_length,
             config.multi_query_group_num,
             config.hidden_size // config.num_attention_heads,
         )
@@ -782,7 +782,7 @@ def get_model(args: argparse.Namespace, hf_config):
     create_metadata_func(
         bb,
         model_name=model,
-        max_window_size=config.seq_length,
+        max_window_size=config.max_sequence_length,
         stop_tokens=[0],
         add_prefix_space=False,
     )
@@ -794,8 +794,8 @@ def get_model(args: argparse.Namespace, hf_config):
         mod[gv] = func.with_attr(
             "tir_var_upper_bound",
             {
-                "n": config.seq_length,
-                "m": config.seq_length,
+                "n": config.max_sequence_length,
+                "m": config.max_sequence_length,
             },
         )
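
The rename above turns the ChatGLM-specific seq_length field into the max_sequence_length attribute that the other model configs expose, while keeping the existing cap at 2048 in the constructor. A stripped-down stand-in for ChatGLMConfig, hypothetical and only to show the renamed field and the cap (the real class takes many more keyword arguments):

# Hypothetical sketch, not the real ChatGLMConfig.
class ChatGLMConfigSketch:
    def __init__(self, max_sequence_length: int = 2048, **kwargs):
        # Mirrors the diff: the stored value is capped at 2048.
        self.max_sequence_length = min(2048, max_sequence_length)


cfg = ChatGLMConfigSketch(max_sequence_length=32768)
assert cfg.max_sequence_length == 2048        # clamped by the constructor
assert hasattr(cfg, "max_sequence_length")    # the check core.py now performs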

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import argparse
import os
import unittest
from unittest.mock import MagicMock, mock_open, patch

from mlc_llm import utils
from mlc_llm.core import build_model_from_args


class MockMkdir(object):
    def __init__(self):
        self.received_args = None

    def __call__(self, *args):
        self.received_args = args


class BuildModelTest(unittest.TestCase):
    def setUp(self):
        self._orig_mkdir = os.mkdir
        os.mkdir = MockMkdir()

        self.mock_args = argparse.Namespace()
        self.mock_args.quantization = utils.quantization_schemes["q8f16_1"]
        self.mock_args.debug_dump = False
        self.mock_args.use_cache = False
        self.mock_args.sep_embed = False
        self.mock_args.build_model_only = True
        self.mock_args.use_safetensors = False
        self.mock_args.convert_weight_only = False
        self.mock_args.no_cutlass_attn = True
        self.mock_args.no_cutlass_norm = True
        self.mock_args.reuse_lib = True
        self.mock_args.artifact_path = "/tmp/"
        self.mock_args.model_path = "/tmp/"
        self.mock_args.model = "/tmp/"
        self.mock_args.target_kind = "cuda"
        self.mock_args.max_seq_len = 2048

    def tearDown(self):
        os.mkdir = self._orig_mkdir

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{}]))
    def test_llama_model(self, mock_file):
        self.mock_args.model_category = "llama"

        build_model_from_args(self.mock_args)

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{
        "use_parallel_residual": False,
        "hidden_size": 32,
        "intermediate_size": 32,
        "num_attention_heads": 32,
        "num_hidden_layers": 28,
        "vocab_size": 1024,
        "rotary_pct": 1,
        "rotary_emb_base": 1,
        "layer_norm_eps": 1,
    }]))
    def test_gpt_neox_model(self, mock_file):
        self.mock_args.model_category = "gpt_neox"
        self.mock_args.model = "dolly-test"

        build_model_from_args(self.mock_args)

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{}]))
    def test_gpt_bigcode_model(self, mock_file):
        self.mock_args.model_category = "gpt_bigcode"
        self.mock_args.model = "gpt_bigcode"

        build_model_from_args(self.mock_args)

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{}]))
    def test_minigpt_model(self, mock_file):
        self.mock_args.model_category = "minigpt"
        self.mock_args.model = "minigpt4-7b"

        build_model_from_args(self.mock_args)

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{
        "vocab_size": 1024,
        "n_embd": 32,
        "n_inner": 32,
        "n_head": 32,
        "n_layer": 28,
        "bos_token_id": 28,
        "eos_token_id": 1,
        "rotary_dim": 1,
        "tie_word_embeddings": 1,
    }]))
    def test_gptj_model(self, mock_file):
        self.mock_args.model_category = "gptj"
        self.mock_args.model = "gpt-j-"

        build_model_from_args(self.mock_args)

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{
        "num_hidden_layers": 16,
        "vocab_size": 1024,
        "hidden_size": 16,
        "intermediate_size": 32,
    }]))
    def test_rwkv_model(self, mock_file):
        self.mock_args.model_category = "rwkv"
        self.mock_args.model = "rwkv-"

        build_model_from_args(self.mock_args)

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{}]))
    def test_chatglm_model(self, mock_file):
        self.mock_args.model_category = "chatglm"
        self.mock_args.model = "chatglm2"

        build_model_from_args(self.mock_args)
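
The new test (its path is not shown in this view) patches builtins.open and json.load so that build_model_from_args sees a canned HuggingFace-style config without touching the filesystem, and swaps os.mkdir for a no-op so no artifact directories are created. A self-contained sketch of that mocking pattern; load_hf_config below is a hypothetical helper used only to make the example runnable, not part of mlc_llm:

import json
from unittest.mock import MagicMock, mock_open, patch


def load_hf_config(path):
    # Hypothetical stand-in for the config-loading step inside
    # build_model_from_args.
    with open(path, "r") as f:
        return json.load(f)


# open() returns fake file data and json.load() yields a canned config dict,
# so the code under test never reads a real file.
with patch("builtins.open", new_callable=mock_open, read_data="data"):
    with patch("json.load", MagicMock(side_effect=[{"hidden_size": 32}])):
        cfg = load_hf_config("/tmp/config.json")
        assert cfg == {"hidden_size": 32}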

0 commit comments
