1+
2+ import argparse
3+ import os
4+ import unittest
5+ from unittest .mock import MagicMock , mock_open , patch
6+
7+ from mlc_llm import utils
8+
9+ from mlc_llm .core import build_model_from_args
10+
11+
class MockMkdir(object):
    """Callable stand-in for ``os.mkdir`` that records its last invocation.

    Calling an instance creates nothing on disk; it only stores the
    arguments it was given so a test can inspect them afterwards.

    Attributes:
        received_args: Tuple of positional arguments from the most recent
            call, or ``None`` if the instance has not been called yet.
        received_kwargs: Dict of keyword arguments from the most recent
            call, or ``None`` if the instance has not been called yet.
    """

    def __init__(self):
        self.received_args = None
        self.received_kwargs = None

    def __call__(self, *args, **kwargs):
        """Record the call and return ``None``.

        Keyword arguments are accepted so that callers using the keyword
        parameters of ``os.mkdir`` (``mode=...``, ``dir_fd=...``) do not
        fail with ``TypeError`` when this stub is installed.
        """
        self.received_args = args
        self.received_kwargs = kwargs
18+
class BuildModelTest(unittest.TestCase):
    """Smoke tests that run ``build_model_from_args`` once per model category.

    Each test patches ``builtins.open`` and ``json.load`` so that the model
    configuration comes from an in-test dictionary, and ``os.mkdir`` is
    replaced by a ``MockMkdir`` stub for the duration of the test. The tests
    contain no assertions: a test passes when ``build_model_from_args``
    returns without raising.
    """

    def setUp(self):
        """Stub out ``os.mkdir`` and build the default argument namespace."""
        # Register the undo step through addCleanup immediately after
        # patching: cleanups still run when a later statement in setUp
        # raises, whereas tearDown is skipped in that case and the stub
        # would stay installed for the rest of the test run.
        mkdir_patcher = patch.object(os, "mkdir", MockMkdir())
        self.mock_mkdir = mkdir_patcher.start()
        self.addCleanup(mkdir_patcher.stop)

        self.mock_args = argparse.Namespace(
            quantization=utils.quantization_schemes["q8f16_1"],
            debug_dump=False,
            use_cache=False,
            sep_embed=False,
            build_model_only=True,
            use_safetensors=False,
            convert_weight_only=False,
            no_cutlass_attn=True,
            no_cutlass_norm=True,
            reuse_lib=True,
            artifact_path="/tmp/",
            model_path="/tmp/",
            model="/tmp/",
            target_kind="cuda",
            max_seq_len=2048,
        )

    def _build(self, model_category, model=None):
        """Select the model on the shared namespace and run the build.

        Args:
            model_category: Value assigned to ``mock_args.model_category``.
            model: Value assigned to ``mock_args.model``; when ``None`` the
                default set in ``setUp`` is left in place.
        """
        self.mock_args.model_category = model_category
        if model is not None:
            self.mock_args.model = model
        build_model_from_args(self.mock_args)

    # NOTE: ``json.load`` is patched with a one-element ``side_effect`` list,
    # so it hands out the given config exactly once per test; a second call
    # would raise ``StopIteration``.

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{}]))
    def test_llama_model(self, mock_file):
        """Build with the ``llama`` category and an empty config."""
        self._build("llama")

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch(
        "json.load",
        MagicMock(
            side_effect=[
                {
                    "use_parallel_residual": False,
                    "hidden_size": 32,
                    "intermediate_size": 32,
                    "num_attention_heads": 32,
                    "num_hidden_layers": 28,
                    "vocab_size": 1024,
                    "rotary_pct": 1,
                    "rotary_emb_base": 1,
                    "layer_norm_eps": 1,
                }
            ]
        ),
    )
    def test_gpt_neox_model(self, mock_file):
        """Build with the ``gpt_neox`` category and a small explicit config."""
        self._build("gpt_neox", model="dolly-test")

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{}]))
    def test_gpt_bigcode_model(self, mock_file):
        """Build with the ``gpt_bigcode`` category and an empty config."""
        self._build("gpt_bigcode", model="gpt_bigcode")

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{}]))
    def test_minigpt_model(self, mock_file):
        """Build with the ``minigpt`` category and an empty config."""
        self._build("minigpt", model="minigpt4-7b")

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch(
        "json.load",
        MagicMock(
            side_effect=[
                {
                    "vocab_size": 1024,
                    "n_embd": 32,
                    "n_inner": 32,
                    "n_head": 32,
                    "n_layer": 28,
                    "bos_token_id": 28,
                    "eos_token_id": 1,
                    "rotary_dim": 1,
                    "tie_word_embeddings": 1,
                }
            ]
        ),
    )
    def test_gptj_model(self, mock_file):
        """Build with the ``gptj`` category and a small explicit config."""
        self._build("gptj", model="gpt-j-")

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch(
        "json.load",
        MagicMock(
            side_effect=[
                {
                    "num_hidden_layers": 16,
                    "vocab_size": 1024,
                    "hidden_size": 16,
                    "intermediate_size": 32,
                }
            ]
        ),
    )
    def test_rwkv_model(self, mock_file):
        """Build with the ``rwkv`` category and a small explicit config."""
        self._build("rwkv", model="rwkv-")

    @patch("builtins.open", new_callable=mock_open, read_data="data")
    @patch("json.load", MagicMock(side_effect=[{}]))
    def test_chatglm_model(self, mock_file):
        """Build with the ``chatglm`` category and an empty config."""
        self._build("chatglm", model="chatglm2")
0 commit comments