diff --git a/ChuanhuChatbot.py b/ChuanhuChatbot.py index 4074dc1c..42f0f297 100644 --- a/ChuanhuChatbot.py +++ b/ChuanhuChatbot.py @@ -359,7 +359,7 @@ token_count, top_p, temperature, - use_streaming_checkbox, + gr.State(0), model_select_dropdown, ], [chatbot, history, status_display, token_count], diff --git a/chat_func.py b/chat_func.py index dd9a7954..374178f3 100644 --- a/chat_func.py +++ b/chat_func.py @@ -371,9 +371,8 @@ def predict( all_token_counts, top_p, temperature, - stream=False, + max_token//2, selected_model=selected_model, - hidden=True, ) for chatbot, history, status_text, all_token_counts in iter: status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}" @@ -410,9 +409,10 @@ def retry( stream=stream, selected_model=selected_model, ) - logging.info("重试完毕") + logging.info("重试中……") for x in iter: yield x + logging.info("重试完毕") def reduce_token_size( @@ -423,9 +423,8 @@ def reduce_token_size( token_count, top_p, temperature, - stream=False, + max_token_count, selected_model=MODELS[0], - hidden=False, ): logging.info("开始减少token数量……") iter = predict( @@ -437,17 +436,21 @@ def reduce_token_size( token_count, top_p, temperature, - stream=stream, selected_model=selected_model, should_check_token_count=False, ) logging.info(f"chatbot: {chatbot}") + flag = False for chatbot, history, status_text, previous_token_count in iter: - history = history[-2:] - token_count = previous_token_count[-1:] - if hidden: - chatbot.pop() - yield chatbot, history, construct_token_message( - sum(token_count), stream=stream + num_chat = find_n(previous_token_count, max_token_count) + if flag: + chatbot = chatbot[:-1] + flag = True + history = history[-2*num_chat:] if num_chat > 0 else [] + token_count = previous_token_count[-num_chat:] if num_chat > 0 else [] + msg = f"保留了最近{num_chat}轮对话" + yield chatbot, history, msg + "," + construct_token_message( + sum(token_count) if len(token_count) > 0 else 0, ), token_count + logging.info(msg) logging.info("减少token数量完毕") \ No newline at end of file diff --git a/utils.py b/utils.py index 637b2e12..d485a5ad 100644 --- a/utils.py +++ b/utils.py @@ -37,9 +37,10 @@ def count_token(message): length = len(encoding.encode(input_str)) return length + def markdown_to_html_with_syntax_highlight(md_str): def replacer(match): - lang = match.group(1) or 'text' + lang = match.group(1) or "text" code = match.group(2) try: @@ -50,60 +51,65 @@ def replacer(match): formatter = HtmlFormatter() highlighted_code = highlight(code, lexer, formatter) - return f"
{highlighted_code}
"
+ return f'{highlighted_code}
'
- code_block_pattern = r'```(\w+)?\n([\s\S]+?)\n```'
+ code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
html_str = markdown(md_str)
return html_str
+
def normalize_markdown(md_text: str) -> str:
- lines = md_text.split('\n')
+ lines = md_text.split("\n")
normalized_lines = []
inside_list = False
for i, line in enumerate(lines):
- if re.match(r'^(\d+\.|-|\*|\+)\s', line.strip()):
- if not inside_list and i > 0 and lines[i - 1].strip() != '':
- normalized_lines.append('')
+ if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
+ if not inside_list and i > 0 and lines[i - 1].strip() != "":
+ normalized_lines.append("")
inside_list = True
normalized_lines.append(line)
- elif inside_list and line.strip() == '':
- if i < len(lines) - 1 and not re.match(r'^(\d+\.|-|\*|\+)\s', lines[i + 1].strip()):
+ elif inside_list and line.strip() == "":
+ if i < len(lines) - 1 and not re.match(
+ r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
+ ):
normalized_lines.append(line)
continue
else:
inside_list = False
normalized_lines.append(line)
- return '\n'.join(normalized_lines)
+ return "\n".join(normalized_lines)
+
def convert_mdtext(md_text):
- code_block_pattern = re.compile(r'```(.*?)(?:```|$)', re.DOTALL)
+ code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
code_blocks = code_block_pattern.findall(md_text)
non_code_parts = code_block_pattern.split(md_text)[::2]
result = []
- for non_code, code in zip(non_code_parts, code_blocks + ['']):
+ for non_code, code in zip(non_code_parts, code_blocks + [""]):
if non_code.strip():
non_code = normalize_markdown(non_code)
- result.append(mdtex2html.convert(non_code, extensions=['tables']))
+ result.append(mdtex2html.convert(non_code, extensions=["tables"]))
if code.strip():
- _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
+ _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
code = f"```{code}\n\n```"
code = markdown_to_html_with_syntax_highlight(code)
result.append(code)
result = "".join(result)
return result
+
def detect_language(code):
if code.startswith("\n"):
first_line = ""
else:
- first_line = code.strip().split('\n', 1)[0]
- language = first_line.lower() if first_line else ''
- code_without_language = code[len(first_line):].lstrip() if first_line else code
+ first_line = code.strip().split("\n", 1)[0]
+ language = first_line.lower() if first_line else ""
+ code_without_language = code[len(first_line) :].lstrip() if first_line else code
return language, code_without_language
@@ -336,26 +342,40 @@ def replace_today(prompt):
today = datetime.datetime.today().strftime("%Y-%m-%d")
return prompt.replace("{current_date}", today)
+
def get_geoip():
- response = requests.get('https://ipapi.co/json/', timeout=5)
+ response = requests.get("https://ipapi.co/json/", timeout=5)
try:
data = response.json()
except:
- data = {
- "error": True,
- "reason" : "连接ipapi失败"
- }
+ data = {"error": True, "reason": "连接ipapi失败"}
if "error" in data.keys():
logging.warning(f"无法获取IP地址信息。\n{data}")
- if data['reason'] == "RateLimited":
- return f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
+ if data["reason"] == "RateLimited":
+ return (
+ f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
+ )
else:
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
else:
- country = data['country_name']
+ country = data["country_name"]
if country == "China":
text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
else:
text = f"您的IP区域:{country}。"
logging.info(text)
- return text
\ No newline at end of file
+ return text
+
+
+def find_n(lst, max_num):
+ n = len(lst)
+ total = sum(lst)
+
+ if total < max_num:
+ return n
+
+ for i in range(len(lst)):
+ if total - lst[i] < max_num:
+ return n - i -1
+ total = total - lst[i]
+ return 1