Skip to content

Commit

Permalink
改进减少token逻辑
Browse files Browse the repository at this point in the history
  • Loading branch information
GaiZhenbiao committed Mar 22, 2023
1 parent b0a1d94 commit a2dfe6a
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 39 deletions.
2 changes: 1 addition & 1 deletion ChuanhuChatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@
token_count,
top_p,
temperature,
use_streaming_checkbox,
gr.State(0),
model_select_dropdown,
],
[chatbot, history, status_display, token_count],
Expand Down
27 changes: 15 additions & 12 deletions chat_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,8 @@ def predict(
all_token_counts,
top_p,
temperature,
stream=False,
max_token//2,
selected_model=selected_model,
hidden=True,
)
for chatbot, history, status_text, all_token_counts in iter:
status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
Expand Down Expand Up @@ -410,9 +409,10 @@ def retry(
stream=stream,
selected_model=selected_model,
)
logging.info("重试完毕")
logging.info("重试中……")
for x in iter:
yield x
logging.info("重试完毕")


def reduce_token_size(
Expand All @@ -423,9 +423,8 @@ def reduce_token_size(
token_count,
top_p,
temperature,
stream=False,
max_token_count,
selected_model=MODELS[0],
hidden=False,
):
logging.info("开始减少token数量……")
iter = predict(
Expand All @@ -437,17 +436,21 @@ def reduce_token_size(
token_count,
top_p,
temperature,
stream=stream,
selected_model=selected_model,
should_check_token_count=False,
)
logging.info(f"chatbot: {chatbot}")
flag = False
for chatbot, history, status_text, previous_token_count in iter:
history = history[-2:]
token_count = previous_token_count[-1:]
if hidden:
chatbot.pop()
yield chatbot, history, construct_token_message(
sum(token_count), stream=stream
num_chat = find_n(previous_token_count, max_token_count)
if flag:
chatbot = chatbot[:-1]
flag = True
history = history[-2*num_chat:] if num_chat > 0 else []
token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
msg = f"保留了最近{num_chat}轮对话"
yield chatbot, history, msg + "," + construct_token_message(
sum(token_count) if len(token_count) > 0 else 0,
), token_count
logging.info(msg)
logging.info("减少token数量完毕")
72 changes: 46 additions & 26 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ def count_token(message):
length = len(encoding.encode(input_str))
return length


def markdown_to_html_with_syntax_highlight(md_str):
def replacer(match):
lang = match.group(1) or 'text'
lang = match.group(1) or "text"
code = match.group(2)

try:
Expand All @@ -50,60 +51,65 @@ def replacer(match):
formatter = HtmlFormatter()
highlighted_code = highlight(code, lexer, formatter)

return f"<pre><code class=\"{lang}\">{highlighted_code}</code></pre>"
return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'

code_block_pattern = r'```(\w+)?\n([\s\S]+?)\n```'
code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)

html_str = markdown(md_str)
return html_str


def normalize_markdown(md_text: str) -> str:
lines = md_text.split('\n')
lines = md_text.split("\n")
normalized_lines = []
inside_list = False

for i, line in enumerate(lines):
if re.match(r'^(\d+\.|-|\*|\+)\s', line.strip()):
if not inside_list and i > 0 and lines[i - 1].strip() != '':
normalized_lines.append('')
if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
if not inside_list and i > 0 and lines[i - 1].strip() != "":
normalized_lines.append("")
inside_list = True
normalized_lines.append(line)
elif inside_list and line.strip() == '':
if i < len(lines) - 1 and not re.match(r'^(\d+\.|-|\*|\+)\s', lines[i + 1].strip()):
elif inside_list and line.strip() == "":
if i < len(lines) - 1 and not re.match(
r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
):
normalized_lines.append(line)
continue
else:
inside_list = False
normalized_lines.append(line)

return '\n'.join(normalized_lines)
return "\n".join(normalized_lines)


def convert_mdtext(md_text):
code_block_pattern = re.compile(r'```(.*?)(?:```|$)', re.DOTALL)
code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
code_blocks = code_block_pattern.findall(md_text)
non_code_parts = code_block_pattern.split(md_text)[::2]

result = []
for non_code, code in zip(non_code_parts, code_blocks + ['']):
for non_code, code in zip(non_code_parts, code_blocks + [""]):
if non_code.strip():
non_code = normalize_markdown(non_code)
result.append(mdtex2html.convert(non_code, extensions=['tables']))
result.append(mdtex2html.convert(non_code, extensions=["tables"]))
if code.strip():
_, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
_, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
code = f"```{code}\n\n```"
code = markdown_to_html_with_syntax_highlight(code)
result.append(code)
result = "".join(result)
return result


def detect_language(code):
if code.startswith("\n"):
first_line = ""
else:
first_line = code.strip().split('\n', 1)[0]
language = first_line.lower() if first_line else ''
code_without_language = code[len(first_line):].lstrip() if first_line else code
first_line = code.strip().split("\n", 1)[0]
language = first_line.lower() if first_line else ""
code_without_language = code[len(first_line) :].lstrip() if first_line else code
return language, code_without_language


Expand Down Expand Up @@ -336,26 +342,40 @@ def replace_today(prompt):
today = datetime.datetime.today().strftime("%Y-%m-%d")
return prompt.replace("{current_date}", today)


def get_geoip():
response = requests.get('https://ipapi.co/json/', timeout=5)
response = requests.get("https://ipapi.co/json/", timeout=5)
try:
data = response.json()
except:
data = {
"error": True,
"reason" : "连接ipapi失败"
}
data = {"error": True, "reason": "连接ipapi失败"}
if "error" in data.keys():
logging.warning(f"无法获取IP地址信息。\n{data}")
if data['reason'] == "RateLimited":
return f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
if data["reason"] == "RateLimited":
return (
f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
)
else:
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
else:
country = data['country_name']
country = data["country_name"]
if country == "China":
text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
else:
text = f"您的IP区域:{country}。"
logging.info(text)
return text
return text


def find_n(lst, max_num):
n = len(lst)
total = sum(lst)

if total < max_num:
return n

for i in range(len(lst)):
if total - lst[i] < max_num:
return n - i -1
total = total - lst[i]
return 1

0 comments on commit a2dfe6a

Please sign in to comment.