From a2dfe6a5f563ff829a69c1986bc60b79844fdd5f Mon Sep 17 00:00:00 2001 From: Tuchuanhuhuhu Date: Wed, 22 Mar 2023 11:40:41 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9B=E5=87=8F=E5=B0=91token?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ChuanhuChatbot.py | 2 +- chat_func.py | 27 ++++++++++-------- utils.py | 72 ++++++++++++++++++++++++++++++----------------- 3 files changed, 62 insertions(+), 39 deletions(-) diff --git a/ChuanhuChatbot.py b/ChuanhuChatbot.py index 4074dc1c..42f0f297 100644 --- a/ChuanhuChatbot.py +++ b/ChuanhuChatbot.py @@ -359,7 +359,7 @@ token_count, top_p, temperature, - use_streaming_checkbox, + gr.State(0), model_select_dropdown, ], [chatbot, history, status_display, token_count], diff --git a/chat_func.py b/chat_func.py index dd9a7954..374178f3 100644 --- a/chat_func.py +++ b/chat_func.py @@ -371,9 +371,8 @@ def predict( all_token_counts, top_p, temperature, - stream=False, + max_token//2, selected_model=selected_model, - hidden=True, ) for chatbot, history, status_text, all_token_counts in iter: status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}" @@ -410,9 +409,10 @@ def retry( stream=stream, selected_model=selected_model, ) - logging.info("重试完毕") + logging.info("重试中……") for x in iter: yield x + logging.info("重试完毕") def reduce_token_size( @@ -423,9 +423,8 @@ def reduce_token_size( token_count, top_p, temperature, - stream=False, + max_token_count, selected_model=MODELS[0], - hidden=False, ): logging.info("开始减少token数量……") iter = predict( @@ -437,17 +436,21 @@ def reduce_token_size( token_count, top_p, temperature, - stream=stream, selected_model=selected_model, should_check_token_count=False, ) logging.info(f"chatbot: {chatbot}") + flag = False for chatbot, history, status_text, previous_token_count in iter: - history = history[-2:] - token_count = previous_token_count[-1:] - if hidden: - chatbot.pop() - yield chatbot, history, construct_token_message( - sum(token_count), stream=stream + num_chat = find_n(previous_token_count, max_token_count) + if flag: + chatbot = chatbot[:-1] + flag = True + history = history[-2*num_chat:] if num_chat > 0 else [] + token_count = previous_token_count[-num_chat:] if num_chat > 0 else [] + msg = f"保留了最近{num_chat}轮对话" + yield chatbot, history, msg + "," + construct_token_message( + sum(token_count) if len(token_count) > 0 else 0, ), token_count + logging.info(msg) logging.info("减少token数量完毕") \ No newline at end of file diff --git a/utils.py b/utils.py index 637b2e12..d485a5ad 100644 --- a/utils.py +++ b/utils.py @@ -37,9 +37,10 @@ def count_token(message): length = len(encoding.encode(input_str)) return length + def markdown_to_html_with_syntax_highlight(md_str): def replacer(match): - lang = match.group(1) or 'text' + lang = match.group(1) or "text" code = match.group(2) try: @@ -50,60 +51,65 @@ def replacer(match): formatter = HtmlFormatter() highlighted_code = highlight(code, lexer, formatter) - return f"
{highlighted_code}
" + return f'
{highlighted_code}
' - code_block_pattern = r'```(\w+)?\n([\s\S]+?)\n```' + code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```" md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE) html_str = markdown(md_str) return html_str + def normalize_markdown(md_text: str) -> str: - lines = md_text.split('\n') + lines = md_text.split("\n") normalized_lines = [] inside_list = False for i, line in enumerate(lines): - if re.match(r'^(\d+\.|-|\*|\+)\s', line.strip()): - if not inside_list and i > 0 and lines[i - 1].strip() != '': - normalized_lines.append('') + if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()): + if not inside_list and i > 0 and lines[i - 1].strip() != "": + normalized_lines.append("") inside_list = True normalized_lines.append(line) - elif inside_list and line.strip() == '': - if i < len(lines) - 1 and not re.match(r'^(\d+\.|-|\*|\+)\s', lines[i + 1].strip()): + elif inside_list and line.strip() == "": + if i < len(lines) - 1 and not re.match( + r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip() + ): normalized_lines.append(line) continue else: inside_list = False normalized_lines.append(line) - return '\n'.join(normalized_lines) + return "\n".join(normalized_lines) + def convert_mdtext(md_text): - code_block_pattern = re.compile(r'```(.*?)(?:```|$)', re.DOTALL) + code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL) code_blocks = code_block_pattern.findall(md_text) non_code_parts = code_block_pattern.split(md_text)[::2] result = [] - for non_code, code in zip(non_code_parts, code_blocks + ['']): + for non_code, code in zip(non_code_parts, code_blocks + [""]): if non_code.strip(): non_code = normalize_markdown(non_code) - result.append(mdtex2html.convert(non_code, extensions=['tables'])) + result.append(mdtex2html.convert(non_code, extensions=["tables"])) if code.strip(): - _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题 + _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题 code = f"```{code}\n\n```" code = markdown_to_html_with_syntax_highlight(code) result.append(code) result = "".join(result) return result + def detect_language(code): if code.startswith("\n"): first_line = "" else: - first_line = code.strip().split('\n', 1)[0] - language = first_line.lower() if first_line else '' - code_without_language = code[len(first_line):].lstrip() if first_line else code + first_line = code.strip().split("\n", 1)[0] + language = first_line.lower() if first_line else "" + code_without_language = code[len(first_line) :].lstrip() if first_line else code return language, code_without_language @@ -336,26 +342,40 @@ def replace_today(prompt): today = datetime.datetime.today().strftime("%Y-%m-%d") return prompt.replace("{current_date}", today) + def get_geoip(): - response = requests.get('https://ipapi.co/json/', timeout=5) + response = requests.get("https://ipapi.co/json/", timeout=5) try: data = response.json() except: - data = { - "error": True, - "reason" : "连接ipapi失败" - } + data = {"error": True, "reason": "连接ipapi失败"} if "error" in data.keys(): logging.warning(f"无法获取IP地址信息。\n{data}") - if data['reason'] == "RateLimited": - return f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。" + if data["reason"] == "RateLimited": + return ( + f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。" + ) else: return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。" else: - country = data['country_name'] + country = data["country_name"] if country == "China": text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**" else: text = f"您的IP区域:{country}。" logging.info(text) - return text \ No newline at end of file + return text + + +def find_n(lst, max_num): + n = len(lst) + total = sum(lst) + + if total < max_num: + return n + + for i in range(len(lst)): + if total - lst[i] < max_num: + return n - i -1 + total = total - lst[i] + return 1