-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathutils.py
More file actions
180 lines (149 loc) · 5.57 KB
/
utils.py
File metadata and controls
180 lines (149 loc) · 5.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import json
import re
import os
import asyncio
from typing import Any, Dict, List, Tuple
from openai import AsyncOpenAI
from config import LLM_API_KEY, LLM_BASE_URL, MODEL_NAME, EMBED_MODEL_NAME
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns:
str: Hash string
"""
import hashlib
# Convert all arguments to strings and join them
args_str = "".join([str(arg) for arg in args])
if cache_type:
args_str = f"{cache_type}:{args_str}"
# Compute MD5 hash
return hashlib.md5(args_str.encode()).hexdigest()
async def openai_complete(
prompt,
model="gpt-4o",
system_prompt=None,
history_messages=[],
temperature=0.0,
**kwargs,
) -> str:
model = MODEL_NAME
openai_async_client = AsyncOpenAI(
api_key=LLM_API_KEY, base_url=LLM_BASE_URL
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, temperature=temperature, **kwargs
)
return response.choices[0].message.content
def load_vdb(dataset, documents):
import faiss
from sentence_transformers import SentenceTransformer
index_path = f"./db/vdb/{dataset}_index.faiss"
embed_model = SentenceTransformer(EMBED_MODEL_NAME, trust_remote_code=True)
if os.path.exists(index_path):
index = faiss.read_index(index_path)
return index, embed_model
embeddings = embed_model.encode(documents, show_progress_bar=True)
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(embeddings)
os.makedirs(os.path.dirname(index_path), exist_ok=True)
faiss.write_index(index, index_path)
return index, embed_model
def vdb_retrieve(question, documents, index, embed_model, top_k):
query_embedding = embed_model.encode([question])
_, I = index.search(query_embedding, k=top_k)
retrieved_docs = [documents[i] for i in I[0]]
return retrieved_docs
def extract_words_str(text):
return ' '.join(re.findall(r'[A-Za-z]+', text))
def format_history_context(history):
history_context_str = ""
for i, (q, ctx_sum, a) in enumerate(history):
history_context_str += f"Sub-query {i + 1}: {q}\nRetrieved context:\n{ctx_sum}\nSub-query answer: {a}\n\n"
return history_context_str.strip()
def truncate_str_by_token_size(
text: str,
max_token_size: int,
tokenizer
) -> str:
if max_token_size <= 0:
return ""
encoded = tokenizer.encode(text)
truncated = encoded[:max_token_size]
return tokenizer.decode(truncated, skip_special_tokens=True)
def normalize(text: str) -> str:
import string
"""
Normalize a given string by applying the following transformations:
1. Convert the string to lowercase.
2. Remove punctuation characters.
3. Remove the articles "a", "an", and "the".
4. Normalize whitespace by collapsing multiple spaces into one.
Args:
text (str): The input string to be normalized.
Returns:
str: The normalized string.
"""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(text)))).split()
def parse_expanded_queries(query_expansion_result: str):
import ast
"""
Try to extract and parse a Python-style list of strings from
the model output, even if surrounded by extra text.
"""
text = query_expansion_result.strip()
# Step 1️⃣: 尝试直接安全解析
try:
parsed = ast.literal_eval(text)
if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
return parsed
except Exception:
pass
# Step 2️⃣: 用正则在文本中提取形如 ["...", "..."] 的部分
match = re.search(r'\[[\s\S]*?\]', text)
if match:
list_str = match.group(0)
try:
parsed = ast.literal_eval(list_str)
if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
return parsed
except Exception:
pass
# Step 3️⃣: 如果都失败,就退化为单元素列表
return [text]
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
try:
# Try to get the current event loop
current_loop = asyncio.get_event_loop()
if current_loop.is_closed():
raise RuntimeError("Event loop is closed.")
return current_loop
except RuntimeError:
# If no event loop exists or it is closed, create a new one
# logger.info("Creating a new event loop in main thread.")
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
return new_loop