-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
221 lines (174 loc) · 8.44 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import os
import time
import torch
from typing import Optional
from threading import Thread, Event
import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer
from RAG.utils import Config
from RAG.VectorBase import load_vector_database
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.schema import NodeWithScore
from llama_index.core.vector_stores import VectorStoreQuery
from ipex_llm.transformers import AutoModelForCausalLM
config = Config()
os.environ["OMP_NUM_THREADS"] = config.get("omp_num_threads")
# 设置嵌入模型
embed_model = HuggingFaceEmbedding(model_name=config.get("embedding_model_path"))
# 设置语言模型
# llm = setup_local_llm(config)
persist_dir = config.get("persist_dir")
vector_store = load_vector_database(persist_dir, "load")
##############################rag start#########################
# 加载模型和tokenizer
load_path = config.get("model_path") # 模型路径
model = AutoModelForCausalLM.load_low_bit(load_path, trust_remote_code=True) # 加载低位模型
tokenizer = AutoTokenizer.from_pretrained(load_path, trust_remote_code=True) # 加载对应的tokenizer
# 将模型移动到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 检查是否有GPU可用
model = model.to(device) # 将模型移动到选定的设备上
# 创建 TextIteratorStreamer,用于流式生成文本
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 创建一个停止事件,用于控制生成过程的中断
stop_event = Event()
# 定义用户输入处理函数
def user(user_message, history):
return "", history + [[user_message, None]] # 返回空字符串和更新后的历史记录
# 定义机器人回复生成函数
def bot(history, type):
stop_event.clear() # 重置停止事件
prompt = history[-1][0] # 获取最新的用户输入
######################rag query#################################
# 设置查询
# config.get("question") = prompt
query_str = prompt
query_embedding = embed_model.get_query_embedding(prompt)
# 执行向量存储检索
print("开始执行向量存储检索")
query_mode = "default"
vector_store_query = VectorStoreQuery(
query_embedding=query_embedding, similarity_top_k=2, mode=query_mode
)
query_result = vector_store[type].query(vector_store_query)
# 处理查询结果
print("开始处理检索结果")
nodes_with_scores = []
for index, node in enumerate(query_result.nodes):
score: Optional[float] = None
if query_result.similarities is not None:
score = query_result.similarities[index]
nodes_with_scores.append(NodeWithScore(node=node, score=score))
# Convert retrieved documents into context strings
context_strings = [node.node.text for node in nodes_with_scores]
context_scores = [node.score for node in nodes_with_scores]
# Cutoff of the score/Hyperparameter,
# If the score is small, the query_result has no meaning for the question.
# TODO: advanced tricks in
# https://docs.llamaindex.ai/en/stable/module_guides/querying/node_postprocessors/node_postprocessors/
if context_scores[0] < 0.3:
context_strings = [""]
# Concatenate context strings to form a single context string for the LLM
context_string = "\n".join(context_strings)
print("Scores list")
print(context_scores)
'''
# 设置检索器
retriever = rag.VectorDBRetriever(
vector_store, embed_model, query_mode="default", similarity_top_k=1
)
print(f"Query engine created with retriever: {type(retriever).__name__}")
print(f"Query string length: {len(query_str)}")
print(f"Query string: {query_str}")
# 创建查询引擎
print("准备与llm对话")
# synth = get_response_synthesizer(streaming=True,llm=llm)
query_engine = RetrieverQueryEngine.from_args(retriever, llm=llm)
# 执行查询
print("开始RAG最后生成")
start_time_rag = time.time()
response_rag = query_engine.query(query_str)
print(f"\n\nRAG最后生成完成,用时: {end_time - start_time_rag:.2f} 秒")
print(str(response_rag))
'''
######################rag query end#################################
'''
chat with history
(tiny llms have some problems with chatting, so try to do muti tests)
'''
# TODO: more awesome chat templates
messages = []
for user_msg, response in history:
if user_msg and not response: # 如果当前正在处理的用户消息没有响应
break # 结束循环,只考虑之前的对话
messages.extend([
{"role": "user", "content": user_msg},
{"role": "assistant", "content": response}
])
messages.extend([
{"role": "user", "content":
f'''{context_string}\n
请根据前面的知识参考, 回应下面用户的消息: {prompt}
你是一个心理咨询AI助手, 你的目标是帮助用户舒服地分享他们的想法和情感。你的回答应该是同情、鼓励和支持的,同时保持温暖和温和的语气。
要求:你的回复应该是简单温和的一句话,至少包括2个部分: 承认他们的情绪并确认他们的感受;引导来访者继续阐述他们的感受和体验 (必须有这一步)
示例:
用户:我最近一直感到很焦虑。
回应:听到你感到焦虑,我很抱歉。焦虑确实让人很难受。你觉得是什么原因让你感到特别焦虑呢?我在这里聆听你。
再次强调:请一定要记得继续引导来访者倾诉!!!
'''}
])
print(messages)
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) # 应用聊天模板
model_inputs = tokenizer([text], return_tensors="pt").to(device) # 对输入进行编码并移到指定设备
print(f"\n用户输入: {prompt}")
print("模型输出: ", end="", flush=True)
start_time = time.time() # 记录开始时间
# 设置生成参数
generation_kwargs = dict(
model_inputs,
streamer=streamer,
max_new_tokens=512, # 最大生成512个新token
do_sample=True, # 使用采样
top_p=0.7, # 使用top-p采样
temperature=0.95, # 控制生成的随机性
)
# 在新线程中运行模型生成
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer: # 迭代生成的文本流
if stop_event.is_set(): # 检查是否需要停止生成
print("\n生成被用户停止")
break
generated_text += new_text
print(new_text, end="", flush=True)
history[-1][1] = generated_text # 更新历史记录中的回复
yield history # 逐步返回更新的历史记录
end_time = time.time()
print(f"\n\n生成完成,用时: {end_time - start_time:.2f} 秒")
# 定义停止生成函数
def stop_generation():
stop_event.set() # 设置停止事件
# 使用Gradio创建Web界面
with gr.Blocks() as demo:
gr.Markdown("# MindEaseAI Chatbot")
chatbot = gr.Chatbot(label="AI 心理咨询助手", show_label=True) # 聊天界面组件
type_selector = gr.Dropdown(
choices=[("家庭关系", "type1"), ("个人成长与心理", "type2"), ("人际关系与社会适应", "type3"), ("职业与学习", "type4")],
label="选择类型"
)
msg = gr.Textbox(placeholder="请输入您的问题或感受...", label="您的消息") # 用户输入文本框
clear = gr.Button("清除") # 清除按钮
stop = gr.Button("停止生成") # 停止生成按钮
# 设置用户输入提交后的处理流程
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, [chatbot, type_selector], chatbot
)
clear.click(lambda: None, None, chatbot, queue=False) # 清除按钮功能
stop.click(stop_generation, queue=False) # 停止生成按钮功能
if __name__ == "__main__":
print("启动 Gradio 界面...")
demo.queue() # 启用队列处理请求
# 提示用户输入DSW号
dsw_number = input("请输入DSW号 (例如: 525085)")
root_path = f"/dsw-{dsw_number}/proxy/7860/"
demo.launch(root_path=root_path, share=True) # 兼容魔搭情况下的路由