From 6d7fb20b5a16b253bee33ce24c4589c600a6c6e4 Mon Sep 17 00:00:00 2001 From: qingxu fu <505030475@qq.com> Date: Thu, 30 Mar 2023 11:05:38 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E9=85=8D=E7=BD=AE=E7=9A=84?= =?UTF-8?q?=E8=AF=BB=E5=8F=96=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- check_proxy.py | 7 ++++--- main.py | 7 ++++--- predict.py | 8 +++++--- toolbox.py | 24 ++++++++++++++++++------ 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/check_proxy.py b/check_proxy.py index 39c89728cc..a6919dd37a 100644 --- a/check_proxy.py +++ b/check_proxy.py @@ -21,6 +21,7 @@ def check_proxy(proxies): if __name__ == '__main__': import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染 - try: from config_private import proxies # 放自己的秘密如API和代理网址 os.path.exists('config_private.py') - except: from config import proxies - check_proxy(proxies) \ No newline at end of file + from toolbox import get_conf + proxies, = get_conf('proxies') + check_proxy(proxies) + \ No newline at end of file diff --git a/main.py b/main.py index cde222939e..4217304c4f 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,12 @@ import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染 import gradio as gr from predict import predict -from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated +from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到 -try: from config_private import proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION -except: from config import proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION +proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION = \ + get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION') + # 如果WEB_PORT是-1, 则随机选取WEB端口 PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT diff --git a/predict.py b/predict.py index 55a25e6093..712cbd8f60 100644 --- a/predict.py +++ b/predict.py @@ -20,10 +20,12 @@ # config_private.py放自己的秘密如API和代理网址 # 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件 -try: from config_private import proxies, API_URL, API_KEY, TIMEOUT_SECONDS, MAX_RETRY, LLM_MODEL -except: from config import proxies, API_URL, API_KEY, TIMEOUT_SECONDS, MAX_RETRY, LLM_MODEL +from toolbox import get_conf +proxies, API_URL, API_KEY, TIMEOUT_SECONDS, MAX_RETRY, LLM_MODEL = \ + get_conf('proxies', 'API_URL', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'LLM_MODEL') -timeout_bot_msg = '[local] Request timeout, network error. please check proxy settings in config.py.' +timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \ + '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。' def get_full_error(chunk, stream_response): """ diff --git a/toolbox.py b/toolbox.py index f0ec566893..be7e7cac00 100644 --- a/toolbox.py +++ b/toolbox.py @@ -1,4 +1,4 @@ -import markdown, mdtex2html, threading +import markdown, mdtex2html, threading, importlib, traceback from show_math import convert as convert_math from functools import wraps @@ -7,9 +7,9 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp 调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断 """ import time - try: from config_private import TIMEOUT_SECONDS, MAX_RETRY - except: from config import TIMEOUT_SECONDS, MAX_RETRY from predict import predict_no_ui + from toolbox import get_conf + TIMEOUT_SECONDS, MAX_RETRY = get_conf('TIMEOUT_SECONDS', 'MAX_RETRY') # 多线程的时候,需要一个mutable结构在不同线程之间传递信息 # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息 mutable = [None, ''] @@ -80,10 +80,9 @@ def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PO try: yield from f(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT) except Exception as e: - import traceback from check_proxy import check_proxy - try: from config_private import proxies - except: from config import proxies + from toolbox import get_conf + proxies, = get_conf('proxies') tb_str = regular_txt_to_markdown(traceback.format_exc()) chatbot[-1] = (chatbot[-1][0], f"[Local Message] 实验性函数调用出错: \n\n {tb_str} \n\n 当前代理可用性: \n\n {check_proxy(proxies)}") yield chatbot, history, f'异常 {e}' @@ -218,3 +217,16 @@ def on_report_generated(files, chatbot): # files.extend(report_files) chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧文件上传区,请查收。']) return report_files, chatbot + +def get_conf(*args): + # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到 + res = [] + for arg in args: + try: r = getattr(importlib.import_module('config_private'), arg) + except: r = getattr(importlib.import_module('config'), arg) + res.append(r) + # 在读取API_KEY时,检查一下是不是忘了改config + if arg=='API_KEY' and len(r) != 51: + assert False, "正确的API_KEY密钥是51位,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \ + "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)" + return res