-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_utils.py
More file actions
135 lines (109 loc) · 3.6 KB
/
Copy pathllm_utils.py
File metadata and controls
135 lines (109 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
大模型交互工具模块 - 调用 LLM API 生成 SQL
"""
import requests
from config import LLM_CONFIG
from prompts import (
build_sql_generation_prompt,
build_sql_fix_prompt,
SYSTEM_PROMPT,
)
def _call_openai_compatible_api(messages: list, **kwargs) -> str:
"""
调用兼容 OpenAI API 格式的大模型接口
Args:
messages: 对话消息列表
**kwargs: 额外的请求参数
Returns:
模型返回的文本内容
"""
headers = {
"Authorization": f"Bearer {LLM_CONFIG['api_key']}",
"Content-Type": "application/json",
}
payload = {
"model": kwargs.get("model", LLM_CONFIG["model"]),
"messages": messages,
"temperature": kwargs.get("temperature", LLM_CONFIG["temperature"]),
"max_tokens": kwargs.get("max_tokens", LLM_CONFIG["max_tokens"]),
}
response = requests.post(
url=f"{LLM_CONFIG['api_base'].rstrip('/')}/chat/completions",
headers=headers,
json=payload,
timeout=120,
)
response.raise_for_status()
result = response.json()
return result["choices"][0]["message"]["content"].strip()
def generate_sql(user_question: str, schema_desc: str) -> str:
"""
根据用户自然语言问题生成 SQL 语句
Args:
user_question: 用户的自然语言问题
schema_desc: 数据库表结构描述
Returns:
生成的 SQL 语句
"""
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": build_sql_generation_prompt(user_question, schema_desc),
},
]
sql = _call_openai_compatible_api(messages)
# 清理可能的 markdown 标记
sql = sql.strip()
if sql.startswith("```sql"):
sql = sql[6:]
elif sql.startswith("```"):
sql = sql[3:]
if sql.endswith("```"):
sql = sql[:-3]
return sql.strip()
def fix_sql(sql: str, error_msg: str, schema_desc: str, original_question: str) -> str:
"""
当 SQL 执行报错时,用 LLM 修正 SQL
Args:
sql: 之前生成的错误 SQL
error_msg: 数据库返回的错误信息
schema_desc: 表结构描述
original_question: 用户最初的问题
Returns:
修正后的 SQL 语句
"""
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": build_sql_fix_prompt(sql, error_msg, schema_desc, original_question),
},
]
sql_fixed = _call_openai_compatible_api(messages)
# 清理可能的 markdown 标记
sql_fixed = sql_fixed.strip()
if sql_fixed.startswith("```sql"):
sql_fixed = sql_fixed[6:]
elif sql_fixed.startswith("```"):
sql_fixed = sql_fixed[3:]
if sql_fixed.endswith("```"):
sql_fixed = sql_fixed[:-3]
return sql_fixed.strip()
def test_llm_connection() -> tuple:
"""
测试 LLM API 连接是否正常
Returns:
(bool, str) 成功返回 (True, 成功消息),失败返回 (False, 错误消息)
"""
try:
messages = [
{"role": "user", "content": "回复'Hello'表示连接成功,只回复Hello三个字,不要加其他内容。"}
]
result = _call_openai_compatible_api(messages, max_tokens=64)
if "Hello" in result or "hello" in result:
return True, f"LLM API 连接成功!使用的模型: {LLM_CONFIG['model']}"
else:
return True, f"LLM API 响应正常,返回: {result[:50]}"
except Exception as e:
return False, f"LLM API 连接失败: {str(e)}"