Skip to content

Commit 9434126

Browse files
authored
try to support rwkv-4-world (#848)
* fix small bug
* fix merge conflict
* cmake
* refine
* fix compile bug
* run rwkv 7b successfully
* restructure
* update tokenizers-cpp to support rwkv
* fix prompt
* refine
* move logic to conv_template.cc
* fix comment
* refine
1 parent fc42a1d commit 9434126

File tree

6 files changed

+45
-3
lines changed

6 files changed

+45
-3
lines changed

cpp/conv_templates.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,31 @@ Conversation RWKV() {
215215
return conv;
216216
}
217217

218+
Conversation RWKVWorld() {
  // Conversation template for RWKV "world" models (rwkv-4-world family).
  //
  // The system prompt is a one-shot example exchange wrapped in parentheses,
  // matching the prompt format these models were tuned on; the model is then
  // driven with "User"/"Assistant" turns separated by blank lines.
  const std::string kUserPrefix = "User: ";
  const std::string kAssistantPrefix =
      "Assistant: Hi. I am your assistant and I will provide expert "
      "full response in full details. Please feel free to ask any question and I will always answer it.";
  const std::string kDoubleNewLine = "\n\n";
  const std::string prompt =
      "(" + kUserPrefix + "hi" + kDoubleNewLine + kAssistantPrefix + kDoubleNewLine + ")";
  Conversation conv;
  // NOTE: must match the "rwkv_world" key registered in Conversation::FromTemplate;
  // the previous value "rwkv-world" differed from the lookup key, so resolving a
  // template by this name would fail to round-trip.
  conv.name = "rwkv_world";
  conv.system = prompt;
  conv.roles = {"User", "Assistant"};
  conv.messages = {};
  conv.separator_style = SeparatorStyle::kSepRoleMsg;
  conv.offset = 0;
  // Reuse kDoubleNewLine instead of repeating the "\n\n" literal.
  conv.seps = {kDoubleNewLine};
  conv.role_msg_sep = ": ";
  conv.role_empty_sep = ":";
  conv.stop_str = kDoubleNewLine;
  // TODO(mlc-team): add eos to mlc-chat-config
  // and remove eos from stop token setting.
  conv.stop_tokens = {0};  // per the TODO above, the stop-token list currently carries eos (token 0)
  conv.add_bos = false;
  return conv;
}
242+
218243
Conversation Gorilla() {
219244
Conversation conv;
220245
conv.name = "gorilla_v0";
@@ -532,6 +557,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
532557
{"vicuna_v1.1", VicunaV11},
533558
{"conv_one_shot", ConvOneShot},
534559
{"redpajama_chat", RedPajamaChat},
560+
{"rwkv_world", RWKVWorld},
535561
{"rwkv", RWKV},
536562
{"gorilla", Gorilla},
537563
{"guanaco", Guanaco},

cpp/llm_chat.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
5656
std::filesystem::path path(_path);
5757
std::filesystem::path sentencepiece;
5858
std::filesystem::path huggingface;
59+
std::filesystem::path rwkvworld;
5960
CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
6061
if (std::filesystem::is_directory(path)) {
6162
sentencepiece = path / "tokenizer.model";
6263
huggingface = path / "tokenizer.json";
64+
rwkvworld = path / "tokenizer_model";
6365
// Check ByteLevelBPE
6466
{
6567
std::filesystem::path merges_path = path / "merges.txt";
@@ -76,13 +78,17 @@ std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
7678
} else {
7779
sentencepiece = path.parent_path() / "tokenizer.model";
7880
huggingface = path.parent_path() / "tokenizer.json";
81+
rwkvworld = path.parent_path() / "tokenizer_model";
7982
}
8083
if (std::filesystem::exists(sentencepiece)) {
8184
return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
8285
}
8386
if (std::filesystem::exists(huggingface)) {
8487
return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
8588
}
89+
if (std::filesystem::exists(rwkvworld)) {
90+
return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
91+
}
8692
LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
8793
}
8894

mlc_llm/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def build_model_from_args(args: argparse.Namespace):
556556
mod, param_manager, params, model_config = minigpt.get_model(args)
557557
elif args.model_category == "gptj":
558558
mod, param_manager, params, model_config = gptj.get_model(args, config)
559-
elif args.model_category == "rwkv":
559+
elif args.model_category == "rwkv" or args.model_category == "rwkv_world":
560560
mod, param_manager, params, model_config = rwkv.get_model(args, config)
561561
elif args.model_category == "chatglm":
562562
mod, param_manager, params, model_config = chatglm.get_model(args, config)
@@ -572,7 +572,7 @@ def build_model_from_args(args: argparse.Namespace):
572572
utils.save_params(new_params, args.artifact_path)
573573
if args.model_category != "minigpt":
574574
utils.copy_tokenizer(args)
575-
if args.model_category == "rwkv":
575+
if args.model_category == "rwkv" or args.model_category == "rwkv_world":
576576
# TODO: refactor config into model definition
577577
dump_mlc_chat_config(args, top_p=0.6, temperature=1.2, repetition_penalty=0.996)
578578
else:

mlc_llm/dispatch/dispatch_tir_operator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def __init__(self, model: str):
2121

2222
elif model == "rwkv":
2323
lookup = None
24+
25+
elif model == "rwkv_world":
26+
lookup = None
2427

2528
elif model == "gptj":
2629
lookup = None

mlc_llm/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,21 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:
3636
"moss-moon-003-sft": "gptj",
3737
"moss-moon-003-base": "gptj",
3838
"rwkv-": "rwkv",
39+
"rwkv_world": "rwkv_world",
3940
"minigpt": "minigpt",
4041
}
4142
try:
4243
with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f:
4344
config = json.load(i_f)
4445
args.model_category = config["model_type"]
46+
model_path_lower = args.model_path.lower()
47+
if "rwkv" in model_path_lower and "world" in model_path_lower:
48+
args.model_category = "rwkv_world"
4549
except Exception:
4650
args.model_category = ""
4751
model = args.model.lower()
52+
if "rwkv" in model and "world" in model:
53+
model = "rwkv_world"
4854
for prefix, override_category in model_category_override.items():
4955
if model.startswith(prefix):
5056
args.model_category = override_category
@@ -67,6 +73,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:
6773
"gpt-j-": "LM",
6874
"open_llama": "LM",
6975
"rwkv-": "rwkv",
76+
"rwkv_world": "rwkv_world",
7077
"gorilla-": "gorilla",
7178
"guanaco": "guanaco",
7279
"wizardlm-7b": "wizardlm_7b", # first get rid of 7b

0 commit comments

Comments (0)