web_infer.py
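"""Web inference demo: serves a Qwen2.5-7B-Instruct legal assistant fine-tuned
with LoRA (MindSpore/MindNLP on Ascend) behind a streaming Gradio chat UI.

Run with `python web_infer.py`; Gradio listens on 0.0.0.0.
"""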
import os
import threading
import mindspore as ms
from mindspore import context
from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from mindnlp.peft import PeftModel
import gradio as gr
# ================= Configuration =================
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence the tokenizers fork warning

# Paths
MODEL_PATH = "/home/ma-user/work/pretrained/Qwen/Qwen2.5-7B-Instruct"  # base model
ADAPTER_DIR = "./final_lora_output"  # LoRA adapter weights
MERGED_DIR = "./merged_model"        # optional pre-merged checkpoint

# PyNative (eager) mode on the first Ascend NPU
context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend", device_id=0)
# ================= Model loading =================
print("Loading model...")
# A merged checkpoint counts only if its directory exists and is non-empty
# (os.path.isdir is False for missing paths, so listdir is never reached then)
use_merged = os.path.isdir(MERGED_DIR) and bool(os.listdir(MERGED_DIR))
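# Loading strategy: prefer the pre-merged model (a single load, no adapter
# overhead per forward pass); otherwise load the fp16 base model and attach
# the LoRA adapter with PeftModel.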
if use_merged:
    tokenizer = AutoTokenizer.from_pretrained(MERGED_DIR)
    model = AutoModelForCausalLM.from_pretrained(MERGED_DIR, ms_dtype=ms.float16)
else:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, ms_dtype=ms.float16)
    model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
model = model.to("npu:0")
model.set_train(False)  # inference mode
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id  # Qwen ships no pad token by default
print("Model loaded.")
# ================= Inference logic =================
def build_prompt(tokenizer, instruction, user_input=""):
    # System prompt: "You are a rigorous Chinese legal assistant."
    system = "你是严谨的中文法律助手。"
    content = instruction + ("\n" + user_input if user_input else "")
    if hasattr(tokenizer, "apply_chat_template"):
        messages = [{"role": "system", "content": system},
                    {"role": "user", "content": content}]
        return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    # Plain-text fallback (System/User/Assistant labels in Chinese) for
    # tokenizers without a chat template
    return f"系统:{system}\n用户:{content}\n助手:"
def predict(message, history):
    # Fixed generation parameters; history is unused, so each turn is answered independently
    max_len = 1024
    temperature = 0.7
    top_p = 0.9
    full_prompt = build_prompt(tokenizer, message)
    inputs = tokenizer(full_prompt, return_tensors="ms")
    inputs = {k: v.to("npu:0") for k, v in inputs.items()}
    # Stream tokens as they are produced: generate() runs in a background
    # thread and pushes decoded text through the streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_len,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message  # Gradio re-renders the growing reply on each yield
    thread.join()
# ================= Web UI =================
with gr.Blocks(title="法律大模型") as demo:  # "Legal LLM"
    gr.Markdown("# ⚖️ 法律大模型助手 (MindSpore版)")  # "Legal LLM Assistant (MindSpore edition)"
    chatbot_config = gr.Chatbot(
        height=600,
        show_copy_button=True,  # copy button on each message
    )
    gr.ChatInterface(
        predict,
        chatbot=chatbot_config,  # pass in the customized chatbot component
        examples=[  # sample legal questions in Chinese, matching the fine-tuning domain
            ["某人在交通事故中受到了腹壁穿透创伤,该如何鉴定?"],  # forensic appraisal of an abdominal injury
            ["盗窃罪的立案标准是什么?"],  # filing threshold for theft
            ["请说明注册商标的申请流程?"],  # trademark registration procedure
        ],
        description="基于 Qwen2.5 + LoRA 微调的法律问答助手",  # "Legal Q&A assistant based on Qwen2.5 + LoRA fine-tuning"
    )

if __name__ == "__main__":
    demo.queue().launch(share=True, server_name="0.0.0.0")  # enable queuing; serve on all interfaces
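# Example: after `python web_infer.py`, open http://<host>:7860 (Gradio's
# default port) or use the temporary share link printed to the console.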