-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathcontext_manager.py
More file actions
309 lines (258 loc) · 9.59 KB
/
context_manager.py
File metadata and controls
309 lines (258 loc) · 9.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import sqlite3
import os
import json
from datetime import datetime
from loguru import logger
class ChatContextManager:
"""
聊天上下文管理器
负责存储和检索用户与商品之间的对话历史,使用SQLite数据库进行持久化存储。
支持按会话ID检索对话历史,以及议价次数统计。
"""
def __init__(self, max_history=100, db_path="data/chat_history.db"):
"""
初始化聊天上下文管理器
Args:
max_history: 每个对话保留的最大消息数
db_path: SQLite数据库文件路径
"""
self.max_history = max_history
self.db_path = db_path
self._init_db()
def _init_db(self):
"""初始化数据库表结构"""
# 确保数据库目录存在
db_dir = os.path.dirname(self.db_path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 创建消息表
cursor.execute('''
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
item_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
chat_id TEXT
)
''')
# 检查是否需要添加chat_id字段(兼容旧数据库)
cursor.execute("PRAGMA table_info(messages)")
columns = [column[1] for column in cursor.fetchall()]
if 'chat_id' not in columns:
cursor.execute('ALTER TABLE messages ADD COLUMN chat_id TEXT')
logger.info("已为messages表添加chat_id字段")
# 创建索引以加速查询
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_user_item ON messages (user_id, item_id)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_chat_id ON messages (chat_id)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_timestamp ON messages (timestamp)
''')
# 创建基于会话ID的议价次数表
cursor.execute('''
CREATE TABLE IF NOT EXISTS chat_bargain_counts (
chat_id TEXT PRIMARY KEY,
count INTEGER DEFAULT 0,
last_updated DATETIME DEFAULT CURRENT_TIMESTAMP
)
''')
# 创建商品信息表
cursor.execute('''
CREATE TABLE IF NOT EXISTS items (
item_id TEXT PRIMARY KEY,
data TEXT NOT NULL,
price REAL,
description TEXT,
last_updated DATETIME DEFAULT CURRENT_TIMESTAMP
)
''')
conn.commit()
conn.close()
logger.info(f"聊天历史数据库初始化完成: {self.db_path}")
def save_item_info(self, item_id, item_data):
"""
保存商品信息到数据库
Args:
item_id: 商品ID
item_data: 商品信息字典
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
try:
# 从商品数据中提取有用信息
price = float(item_data.get('soldPrice', 0))
description = item_data.get('desc', '')
# 将整个商品数据转换为JSON字符串
data_json = json.dumps(item_data, ensure_ascii=False)
cursor.execute(
"""
INSERT INTO items (item_id, data, price, description, last_updated)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(item_id)
DO UPDATE SET data = ?, price = ?, description = ?, last_updated = ?
""",
(
item_id, data_json, price, description, datetime.now().isoformat(),
data_json, price, description, datetime.now().isoformat()
)
)
conn.commit()
logger.debug(f"商品信息已保存: {item_id}")
except Exception as e:
logger.error(f"保存商品信息时出错: {e}")
conn.rollback()
finally:
conn.close()
def get_item_info(self, item_id):
"""
从数据库获取商品信息
Args:
item_id: 商品ID
Returns:
dict: 商品信息字典,如果不存在返回None
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
try:
cursor.execute(
"SELECT data FROM items WHERE item_id = ?",
(item_id,)
)
result = cursor.fetchone()
if result:
return json.loads(result[0])
return None
except Exception as e:
logger.error(f"获取商品信息时出错: {e}")
return None
finally:
conn.close()
def add_message_by_chat(self, chat_id, user_id, item_id, role, content):
"""
基于会话ID添加新消息到对话历史
Args:
chat_id: 会话ID
user_id: 用户ID (用户消息存真实user_id,助手消息存卖家ID)
item_id: 商品ID
role: 消息角色 (user/assistant)
content: 消息内容
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
try:
# 插入新消息,使用chat_id作为额外标识
cursor.execute(
"INSERT INTO messages (user_id, item_id, role, content, timestamp, chat_id) VALUES (?, ?, ?, ?, ?, ?)",
(user_id, item_id, role, content, datetime.now().isoformat(), chat_id)
)
# 检查是否需要清理旧消息(基于chat_id)
cursor.execute(
"""
SELECT id FROM messages
WHERE chat_id = ?
ORDER BY timestamp DESC
LIMIT ?, 1
""",
(chat_id, self.max_history)
)
oldest_to_keep = cursor.fetchone()
if oldest_to_keep:
cursor.execute(
"DELETE FROM messages WHERE chat_id = ? AND id < ?",
(chat_id, oldest_to_keep[0])
)
conn.commit()
except Exception as e:
logger.error(f"添加消息到数据库时出错: {e}")
conn.rollback()
finally:
conn.close()
def get_context_by_chat(self, chat_id):
"""
基于会话ID获取对话历史
Args:
chat_id: 会话ID
Returns:
list: 包含对话历史的列表
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
try:
cursor.execute(
"""
SELECT role, content FROM messages
WHERE chat_id = ?
ORDER BY timestamp ASC
LIMIT ?
""",
(chat_id, self.max_history)
)
messages = [{"role": role, "content": content} for role, content in cursor.fetchall()]
# 获取议价次数并添加到上下文中
bargain_count = self.get_bargain_count_by_chat(chat_id)
if bargain_count > 0:
messages.append({
"role": "system",
"content": f"议价次数: {bargain_count}"
})
except Exception as e:
logger.error(f"获取对话历史时出错: {e}")
messages = []
finally:
conn.close()
return messages
def increment_bargain_count_by_chat(self, chat_id):
"""
基于会话ID增加议价次数
Args:
chat_id: 会话ID
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
try:
# 使用UPSERT语法直接基于chat_id增加议价次数
cursor.execute(
"""
INSERT INTO chat_bargain_counts (chat_id, count, last_updated)
VALUES (?, 1, ?)
ON CONFLICT(chat_id)
DO UPDATE SET count = count + 1, last_updated = ?
""",
(chat_id, datetime.now().isoformat(), datetime.now().isoformat())
)
conn.commit()
logger.debug(f"会话 {chat_id} 议价次数已增加")
except Exception as e:
logger.error(f"增加议价次数时出错: {e}")
conn.rollback()
finally:
conn.close()
def get_bargain_count_by_chat(self, chat_id):
"""
基于会话ID获取议价次数
Args:
chat_id: 会话ID
Returns:
int: 议价次数
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
try:
cursor.execute(
"SELECT count FROM chat_bargain_counts WHERE chat_id = ?",
(chat_id,)
)
result = cursor.fetchone()
return result[0] if result else 0
except Exception as e:
logger.error(f"获取议价次数时出错: {e}")
return 0
finally:
conn.close()