Commit b7eb9ab

[Feature]: allow model mutex override in core_functional.py (binary-husky#1708)
* allow_core_func_specify_model
* change arg name
* Model override supports hot reloading & raises an error when the override points to a model that does not exist
* allow model mutex override

---------

Co-authored-by: binary-husky <[email protected]>
1 parent 881a596 commit b7eb9ab

File tree

2 files changed: +37 -7 lines changed

Diff for: core_functional.py (+2)

@@ -33,6 +33,8 @@ def get_core_functions():
         "AutoClearHistory": False,
         # [6] Text preprocessing (optional, default None; e.g. write a function that removes all newline characters)
         "PreProcess": None,
+        # [7] Model selection (optional. If unset, the current global model is used; if set, the specified model overrides the global model.)
+        # "ModelOverride": "gpt-3.5-turbo",  # Main use: force the specified model to be used when this core-function button is clicked.
     },
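For illustration, a minimal sketch of what a core-function entry could look like with the new option enabled. The entry name, prompt text, and chosen model below are hypothetical; only the "ModelOverride" key comes from this commit:

    # Hypothetical core-function entry; only "ModelOverride" is new in this commit.
    demo_core_functions = {
        "AcademicPolish": {
            "Prefix": "Polish the following paragraph for academic English:\n\n",
            "Suffix": "",
            "AutoClearHistory": False,
            "PreProcess": None,
            # [7] Clicking this button would use gpt-4, overriding the global model.
            "ModelOverride": "gpt-4",
        }
    }

    # Quick check: which model does each core function pin, if any?
    for name, cfg in demo_core_functions.items():
        print(name, "->", cfg.get("ModelOverride", "<global model>"))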
Diff for: request_llms/bridge_all.py (+35 -7)
@@ -906,6 +906,13 @@ def decode(self, *args, **kwargs):
     AVAIL_LLM_MODELS += [azure_model_name]
 
 
+# -=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=-=-=
+# -=-=-=-=-=-=-=-=-=- ☝️ Model routing above -=-=-=-=-=-=-=-=-=
+# -=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=-=-=
+
+# -=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=-=-=
+# -=-=-=-=-=-=-= 👇 Multi-model routing and switching functions below -=-=-=-=-=-=-=
+# -=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=-=-=
 
 
 def LLM_CATCH_EXCEPTION(f):
@@ -942,13 +949,11 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list, sys
     model = llm_kwargs['llm_model']
     n_model = 1
     if '&' not in model:
-
-        # If querying just 1 LLM:
+        # If querying just "one" LLM (the common case):
         method = model_info[model]["fn_without_ui"]
         return method(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience)
     else:
-
-        # If querying multiple LLMs at once: slightly wordier, but the idea is the same; you need not read this else branch
+        # If querying "multiple" LLMs at once: slightly wordier, but the idea is the same; you need not read this else branch
         executor = ThreadPoolExecutor(max_workers=4)
         models = model.split('&')
         n_model = len(models)
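The '&' syntax above fans a single query out to several models at once. A self-contained sketch of that fan-out pattern follows; query_model is a stand-in for the per-model "fn_without_ui" methods, and the names here are illustrative, not the repository's API:

    from concurrent.futures import ThreadPoolExecutor

    def query_model(model: str, inputs: str) -> str:
        # Stand-in for model_info[model]["fn_without_ui"](...)
        return f"[{model}] echo: {inputs}"

    def query_many(model_spec: str, inputs: str) -> str:
        # "gpt-3.5-turbo&gpt-4" -> ["gpt-3.5-turbo", "gpt-4"]
        models = model_spec.split('&')
        with ThreadPoolExecutor(max_workers=4) as executor:
            results = list(executor.map(lambda m: query_model(m, inputs), models))
        return '\n\n---\n\n'.join(results)

    print(query_many("gpt-3.5-turbo&gpt-4", "hello"))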
@@ -1001,8 +1006,26 @@ def mutex_manager(window_mutex, observe_window):
     res = '<br/><br/>\n\n---\n\n'.join(return_string_collect)
     return res
 
-
-def predict(inputs:str, llm_kwargs:dict, *args, **kwargs):
+# Adjust the model type according to the core-function area's ModelOverride parameter; used inside `predict`
+import importlib
+import core_functional
+def execute_model_override(llm_kwargs, additional_fn, method):
+    functional = core_functional.get_core_functions()
+    if 'ModelOverride' in functional[additional_fn]:
+        # Hot-reload the Prompt & ModelOverride
+        importlib.reload(core_functional)
+        functional = core_functional.get_core_functions()
+        model_override = functional[additional_fn]['ModelOverride']
+        if model_override not in model_info:
+            raise ValueError(f"The ModelOverride parameter '{model_override}' points to a model that is not yet supported; please check the configuration file.")
+        method = model_info[model_override]["fn_with_ui"]
+        llm_kwargs['llm_model'] = model_override
+        return llm_kwargs, additional_fn, method
+    # By default, return the original arguments unchanged
+    return llm_kwargs, additional_fn, method
+
+def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot,
+            history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None):
     """
     Send to the LLM and get the output as a stream.
     Used for the basic conversation feature.
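To see the override logic in isolation, here is a self-contained sketch with stubbed model_info and core-function tables. Everything below is mocked for the example; the real function additionally hot-reloads core_functional via importlib.reload before reading "ModelOverride":

    # Stubs standing in for bridge_all.model_info and core_functional's table.
    model_info = {
        "gpt-3.5-turbo": {"fn_with_ui": lambda: "ui method for gpt-3.5-turbo"},
        "gpt-4":         {"fn_with_ui": lambda: "ui method for gpt-4"},
    }
    functional = {"AcademicPolish": {"ModelOverride": "gpt-4"}}  # hypothetical entry

    def execute_model_override(llm_kwargs, additional_fn, method):
        if 'ModelOverride' in functional.get(additional_fn, {}):
            model_override = functional[additional_fn]['ModelOverride']
            if model_override not in model_info:
                raise ValueError(f"ModelOverride '{model_override}' is not a supported model.")
            method = model_info[model_override]["fn_with_ui"]
            llm_kwargs['llm_model'] = model_override
        return llm_kwargs, additional_fn, method

    llm_kwargs = {"llm_model": "gpt-3.5-turbo"}
    llm_kwargs, _, method = execute_model_override(llm_kwargs, "AcademicPolish", method=None)
    print(llm_kwargs["llm_model"], "->", method())  # gpt-4 -> ui method for gpt-4

Because the real implementation reloads core_functional on each overridden call, edits to ModelOverride in core_functional.py take effect without restarting the app, and an override naming an unknown model fails fast with a ValueError.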
@@ -1021,6 +1044,11 @@ def predict(inputs:str, llm_kwargs:dict, *args, **kwargs):
     """
 
     inputs = apply_gpt_academic_string_mask(inputs, mode="show_llm")
+
     method = model_info[llm_kwargs['llm_model']]["fn_with_ui"] # If this raises, check the AVAIL_LLM_MODELS option in config
-    yield from method(inputs, llm_kwargs, *args, **kwargs)
+
+    if additional_fn:  # Adjust the model according to the core-function area's ModelOverride parameter
+        llm_kwargs, additional_fn, method = execute_model_override(llm_kwargs, additional_fn, method)
+
+    yield from method(inputs, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, stream, additional_fn)
 
0 commit comments