|
1 | 1 | """ |
2 | 2 | LLM客户端封装 |
3 | | -统一使用OpenAI格式调用 |
| 3 | +Supports two backends: |
| 4 | + 1. Prompture (optional) — 12+ providers: LM Studio, Ollama, Claude, Groq, Kimi, etc. |
| 5 | + 2. OpenAI SDK (default fallback) — any OpenAI-compatible API |
| 6 | +Install Prompture for multi-provider support: pip install prompture |
4 | 7 | """ |
5 | 8 |
|
6 | 9 | import json |
7 | 10 | import re |
8 | 11 | from typing import Optional, Dict, Any, List |
9 | | -from openai import OpenAI |
10 | 12 |
|
11 | 13 | from ..config import Config |
12 | 14 |
|
# Try to import Prompture; fall back to OpenAI SDK if not installed.
# _HAS_PROMPTURE is a module-level feature flag consulted by LLMClient
# to pick its backend once at construction time.
try:
    from prompture.agents import Conversation
    from prompture.infra.provider_env import ProviderEnvironment
    from prompture.extraction.tools import strip_think_tags, clean_json_text
    _HAS_PROMPTURE = True
except ImportError:
    _HAS_PROMPTURE = False

# Only import the OpenAI SDK when it is actually needed as the fallback
# backend, so a Prompture-only install does not require the openai package.
if not _HAS_PROMPTURE:
    from openai import OpenAI


# Provider name → ProviderEnvironment field name.
# Used by LLMClient._init_prompture to route a single configured API key to
# the right ProviderEnvironment keyword argument. Local providers such as
# lmstudio/ollama are intentionally absent — they need no key.
_KEY_MAP = {
    "openai": "openai_api_key",
    "claude": "claude_api_key",
    "google": "google_api_key",
    "groq": "groq_api_key",
    "grok": "grok_api_key",
    "openrouter": "openrouter_api_key",
    "moonshot": "moonshot_api_key",
}
13 | 39 |
|
class LLMClient:
    """LLM client.

    When Prompture is installed, ``model`` accepts the ``"provider/model"``
    format for multi-provider support::

        "lmstudio/local-model"              → LM Studio (free, local)
        "ollama/llama3.1:8b"                → Ollama (free, local)
        "openai/gpt-4o"                     → OpenAI
        "claude/claude-sonnet-4-20250514"   → Anthropic
        "moonshot/moonshot-v1-8k"           → Kimi / Moonshot
        "groq/llama-3.1-70b"                → Groq

    Without Prompture, the original OpenAI SDK backend is used (any
    OpenAI-compatible API via LLM_BASE_URL).
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
    ) -> None:
        """Initialize the client, falling back to Config for unset values.

        Args:
            api_key: Provider API key; defaults to ``Config.LLM_API_KEY``.
            base_url: Override endpoint URL; defaults to ``Config.LLM_BASE_URL``.
            model: Model name (``"provider/model"`` with Prompture);
                defaults to ``Config.LLM_MODEL_NAME``.
        """
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model = model or Config.LLM_MODEL_NAME

        # Backend is chosen once at construction time via the module flag.
        if _HAS_PROMPTURE:
            self._init_prompture()
        else:
            self._init_openai()

    # ── Prompture backend ──────────────────────────────────────────

    def _init_prompture(self) -> None:
        """Build the ProviderEnvironment and driver options for Prompture."""
        env_kwargs: Dict[str, Any] = {}
        if self.api_key:
            # Derive the provider from "provider/model"; a bare model name
            # is treated as OpenAI for backward compatibility.
            provider = self.model.split("/")[0] if "/" in self.model else "openai"
            env_field = _KEY_MAP.get(provider)
            if env_field:
                env_kwargs[env_field] = self.api_key

        # Keyless providers (e.g. lmstudio/ollama) leave _env as None so
        # Prompture falls back to its own environment discovery.
        self._env = ProviderEnvironment(**env_kwargs) if env_kwargs else None
        self._driver_options: Dict[str, Any] = {}
        if self.base_url:
            self._driver_options["base_url"] = self.base_url

    def _make_conversation(self, temperature: float, max_tokens: int) -> "Conversation":
        """Create a fresh Conversation carrying per-call sampling options."""
        opts: Dict[str, Any] = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            **self._driver_options,
        }
        return Conversation(self.model, options=opts, env=self._env)

    # ── OpenAI fallback backend ────────────────────────────────────

    def _init_openai(self) -> None:
        """Create the OpenAI SDK client; requires an API key.

        Raises:
            ValueError: if no API key was supplied or configured.
        """
        if not self.api_key:
            raise ValueError("LLM_API_KEY 未配置")
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    # ── Public API ─────────────────────────────────────────────────

    def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        response_format: Optional[Dict] = None,
    ) -> str:
        """Send a chat request and return the model's text response.

        Args:
            messages: List of ``{"role", "content"}`` message dicts.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            response_format: Response format (e.g. JSON mode).
                NOTE(review): only honored on the OpenAI fallback path —
                the Prompture path silently ignores it; confirm whether
                callers rely on it when Prompture is installed.

        Returns:
            Model response text with any ``<think>…</think>`` blocks removed
            (some models, e.g. MiniMax M2.5, embed reasoning in the content).
        """
        if _HAS_PROMPTURE:
            content = self._chat_prompture(messages, temperature, max_tokens)
            return strip_think_tags(content)
        else:
            content = self._chat_openai(messages, temperature, max_tokens, response_format)
            # Fallback: strip think tags with regex when Prompture is not available
            return re.sub(r'<think>[\s\S]*?</think>', '', content).strip()

    def chat_json(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.3,
        max_tokens: int = 4096,
    ) -> Dict[str, Any]:
        """Send a chat request and parse the response as JSON.

        ``response_format`` is deliberately not set so local backends
        (LM Studio / Ollama) stay compatible; the prompt's JSON instruction
        plus the cleanup below are relied on instead.

        Args:
            messages: List of ``{"role", "content"}`` message dicts.
            temperature: Sampling temperature (lower default for JSON).
            max_tokens: Maximum tokens to generate.

        Returns:
            The parsed JSON object.

        Raises:
            ValueError: if the cleaned response is not valid JSON.
        """
        if _HAS_PROMPTURE:
            response = self._chat_prompture(messages, temperature, max_tokens)
            # Prompture's clean_json_text strips think tags + markdown fences
            cleaned = clean_json_text(response)
        else:
            response = self._chat_openai(
                messages, temperature, max_tokens
            )
            # Fallback cleaning when Prompture is not available:
            # remove think tags, then leading/trailing ```json fences.
            cleaned = re.sub(r'<think>[\s\S]*?</think>', '', response).strip()
            cleaned = re.sub(r'^```(?:json)?\s*\n?', '', cleaned, flags=re.IGNORECASE)
            cleaned = re.sub(r'\n?```\s*$', '', cleaned)
            cleaned = cleaned.strip()

        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            raise ValueError(f"LLM返回的JSON格式无效: {cleaned}")

    # ── Private: Prompture path ────────────────────────────────────

    def _chat_prompture(
        self,
        messages: List[Dict[str, str]],
        temperature: float,
        max_tokens: int,
    ) -> str:
        """Replay an OpenAI-style message list into a Prompture Conversation.

        The last non-system message becomes the prompt for ``ask``; earlier
        turns and merged system prompts are pre-seeded as history.

        NOTE(review): appends directly to the private ``conv._messages``
        list — confirm against Prompture's public Conversation API, as this
        may break on library upgrades.
        """
        conv = self._make_conversation(temperature, max_tokens)

        # Inject system prompt (multiple system messages are merged).
        system_parts = [m["content"] for m in messages if m["role"] == "system"]
        if system_parts:
            conv._messages.append({"role": "system", "content": "\n".join(system_parts)})

        # Replay prior turns (everything except the final user message).
        non_system = [m for m in messages if m["role"] != "system"]
        for msg in non_system[:-1]:
            conv._messages.append({"role": msg["role"], "content": msg["content"]})

        prompt = non_system[-1]["content"] if non_system else ""
        return conv.ask(prompt)

    # ── Private: OpenAI fallback path ──────────────────────────────

    def _chat_openai(
        self,
        messages: List[Dict[str, str]],
        temperature: float,
        max_tokens: int,
        response_format: Optional[Dict] = None,
    ) -> str:
        """Issue a chat.completions request via the OpenAI SDK.

        Returns the raw message content of the first choice; think-tag
        stripping is handled by the callers.
        """
        kwargs = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if response_format:
            kwargs["response_format"] = response_format

        response = self.client.chat.completions.create(**kwargs)
        return response.choices[0].message.content
0 commit comments