# chat_config.py
import os
import json
import logging
from typing import Optional
import dspy
from llama_index.llms.bedrock.utils import BEDROCK_FOUNDATION_LLMS
from pydantic import BaseModel
from llama_index.llms.openai.utils import DEFAULT_OPENAI_API_BASE
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
from llama_index.llms.bedrock import Bedrock
from llama_index.llms.ollama import Ollama
from llama_index.core.llms.llm import LLM
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.jinaai import JinaEmbedding
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.postprocessor.jinaai_rerank import JinaRerank
from llama_index.postprocessor.cohere_rerank import CohereRerank
from sqlmodel import Session, select
from google.oauth2 import service_account
from google.auth.transport.requests import Request
from app.rag.node_postprocessor import MetadataPostFilter
from app.rag.node_postprocessor.metadata_post_filter import MetadataFilters
from app.rag.node_postprocessor.baisheng_reranker import BaishengRerank
from app.rag.node_postprocessor.local_reranker import LocalRerank
from app.rag.embeddings.local_embedding import LocalEmbedding
from app.types import LLMProvider, EmbeddingProvider, RerankerProvider
from app.rag.default_prompt import (
    DEFAULT_INTENT_GRAPH_KNOWLEDGE,
    DEFAULT_NORMAL_GRAPH_KNOWLEDGE,
    DEFAULT_CONDENSE_QUESTION_PROMPT,
    DEFAULT_TEXT_QA_PROMPT,
    DEFAULT_REFINE_PROMPT,
    DEFAULT_FURTHER_QUESTIONS_PROMPT,
)
from app.models import (
    ChatEngine as DBChatEngine,
    LLM as DBLLM,
    EmbeddingModel as DBEmbeddingModel,
    RerankerModel as DBRerankerModel,
)
from app.repositories import chat_engine_repo
from app.rag.llms.anthropic_vertex import AnthropicVertex
from app.utils.dspy import get_dspy_lm_by_llama_llm

logger = logging.getLogger(__name__)


class LLMOption(BaseModel):
    intent_graph_knowledge: str = DEFAULT_INTENT_GRAPH_KNOWLEDGE
    normal_graph_knowledge: str = DEFAULT_NORMAL_GRAPH_KNOWLEDGE
    condense_question_prompt: str = DEFAULT_CONDENSE_QUESTION_PROMPT
    text_qa_prompt: str = DEFAULT_TEXT_QA_PROMPT
    refine_prompt: str = DEFAULT_REFINE_PROMPT
    further_questions_prompt: str = DEFAULT_FURTHER_QUESTIONS_PROMPT


class VectorSearchOption(BaseModel):
    metadata_post_filters: Optional[MetadataFilters] = None


class KnowledgeGraphOption(BaseModel):
    enabled: bool = True
    depth: int = 2
    include_meta: bool = True
    with_degree: bool = False
    using_intent_search: bool = True
    relationship_meta_filters: Optional[dict] = None


class ExternalChatEngine(BaseModel):
    # A `None` default needs an Optional annotation; `str = None` is a typing bug.
    stream_chat_api_url: Optional[str] = None
    type: str = "StackVM"


class ChatEngineConfig(BaseModel):
    llm: LLMOption = LLMOption()
    knowledge_graph: KnowledgeGraphOption = KnowledgeGraphOption()
    vector_search: VectorSearchOption = VectorSearchOption()
    post_verification_url: Optional[str] = None
    post_verification_token: Optional[str] = None
    external_engine_config: Optional[ExternalChatEngine] = None

    # Underscore-prefixed attributes are pydantic private attributes; they
    # hold the DB rows this config was resolved from.
    _db_chat_engine: Optional[DBChatEngine] = None
    _db_llm: Optional[DBLLM] = None
    _db_fast_llm: Optional[DBLLM] = None
    _db_reranker: Optional[DBRerankerModel] = None

    def get_db_chat_engine(self) -> Optional[DBChatEngine]:
        return self._db_chat_engine

    @classmethod
    def load_from_db(cls, session: Session, engine_name: str) -> "ChatEngineConfig":
        if not engine_name or engine_name == "default":
            db_chat_engine = chat_engine_repo.get_default_engine(session)
        else:
            db_chat_engine = chat_engine_repo.get_engine_by_name(session, engine_name)
            if not db_chat_engine:
                logger.warning(
                    f"Chat engine {engine_name} not found in DB, using default engine"
                )
                db_chat_engine = chat_engine_repo.get_default_engine(session)
        obj = cls.model_validate(db_chat_engine.engine_options)
        obj._db_chat_engine = db_chat_engine
        obj._db_llm = db_chat_engine.llm
        obj._db_fast_llm = db_chat_engine.fast_llm
        obj._db_reranker = db_chat_engine.reranker
        return obj

    def get_llama_llm(self, session: Session) -> LLM:
        if not self._db_llm:
            return get_default_llm(session)
        return get_llm(
            self._db_llm.provider,
            self._db_llm.model,
            self._db_llm.config,
            self._db_llm.credentials,
        )

    def get_dspy_lm(self, session: Session) -> dspy.LM:
        llama_llm = self.get_llama_llm(session)
        return get_dspy_lm_by_llama_llm(llama_llm)

    def get_fast_llama_llm(self, session: Session) -> LLM:
        if not self._db_fast_llm:
            return get_default_llm(session)
        return get_llm(
            self._db_fast_llm.provider,
            self._db_fast_llm.model,
            self._db_fast_llm.config,
            self._db_fast_llm.credentials,
        )

    def get_fast_dspy_lm(self, session: Session) -> dspy.LM:
        llama_llm = self.get_fast_llama_llm(session)
        return get_dspy_lm_by_llama_llm(llama_llm)

    def get_reranker(self, session: Session) -> Optional[BaseNodePostprocessor]:
        if not self._db_reranker:
            return get_default_reranker_model(session)
        return get_reranker_model(
            self._db_reranker.provider,
            self._db_reranker.model,
            self._db_reranker.top_n,
            self._db_reranker.config,
            self._db_reranker.credentials,
        )

    def get_metadata_filter(self) -> BaseNodePostprocessor:
        return get_metadata_post_filter(self.vector_search.metadata_post_filters)

    def screenshot(self) -> dict:
        # Dump the config for logging/auditing, omitting the long prompt
        # templates and the verification token.
        return self.model_dump(
            exclude={
                "llm": [
                    "condense_question_prompt",
                    "text_qa_prompt",
                    "refine_prompt",
                ],
                "post_verification_token": True,
            }
        )
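
# Usage sketch (illustrative, not part of the original module): resolving a
# chat engine's config and models from the DB. `engine` is assumed to be a
# SQLModel engine bound to the app database; the name "default" resolves to
# the default row via `chat_engine_repo.get_default_engine`.
#
#     with Session(engine) as session:
#         config = ChatEngineConfig.load_from_db(session, "default")
#         llm = config.get_llama_llm(session)
#         fast_lm = config.get_fast_dspy_lm(session)
#         reranker = config.get_reranker(session)  # may be None
#         logger.info("engine options: %s", config.screenshot())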


def get_llm(
    provider: LLMProvider,
    model: str,
    config: dict,
    credentials: str | list | dict | None,
) -> LLM:
    match provider:
        case LLMProvider.OPENAI:
            api_base = config.pop("api_base", DEFAULT_OPENAI_API_BASE)
            return OpenAI(
                model=model,
                api_base=api_base,
                api_key=credentials,
                **config,
            )
        case LLMProvider.OPENAI_LIKE:
            llm = OpenAILike(model=model, api_key=credentials, **config)
            llm.context_window = 200000
            return llm
        case LLMProvider.GEMINI:
            os.environ["GOOGLE_API_KEY"] = credentials
            return Gemini(model=model, api_key=credentials, **config)
        case LLMProvider.BEDROCK:
            access_key_id = credentials["aws_access_key_id"]
            secret_access_key = credentials["aws_secret_access_key"]
            region_name = credentials["aws_region_name"]
            context_size = None
            if model not in BEDROCK_FOUNDATION_LLMS:
                context_size = 200000
            llm = Bedrock(
                model=model,
                aws_access_key_id=access_key_id,
                aws_secret_access_key=secret_access_key,
                region_name=region_name,
                context_size=context_size,
            )
            # Note: the llama-index Bedrock class doesn't assign these values
            # to the corresponding attributes in its constructor, so we set
            # them again here so that `get_dspy_lm_by_llama_llm` can read them.
            llm.aws_access_key_id = access_key_id
            llm.aws_secret_access_key = secret_access_key
            llm.region_name = region_name
            return llm
        case LLMProvider.ANTHROPIC_VERTEX:
            google_creds: service_account.Credentials = (
                service_account.Credentials.from_service_account_info(
                    credentials,
                    scopes=["https://www.googleapis.com/auth/cloud-platform"],
                )
            )
            google_creds.refresh(request=Request())
            if "max_tokens" not in config:
                config.update(max_tokens=4096)
            return AnthropicVertex(model=model, credentials=google_creds, **config)
        case LLMProvider.OLLAMA:
            config.setdefault("request_timeout", 60 * 10)
            config.setdefault("context_window", 4096)
            return Ollama(model=model, **config)
        case _:
            raise ValueError(f"Got unknown LLM provider: {provider}")


def get_default_llm(session: Session) -> LLM:
    db_llm = session.exec(
        select(DBLLM).order_by(DBLLM.is_default.desc()).limit(1)
    ).first()
    if not db_llm:
        raise ValueError("No default LLM found in DB")
    return get_llm(
        db_llm.provider,
        db_llm.model,
        db_llm.config,
        db_llm.credentials,
    )


def get_embedding_model(
    provider: EmbeddingProvider,
    model: str,
    config: dict,
    credentials: str | list | dict | None,
) -> BaseEmbedding:
    match provider:
        case EmbeddingProvider.OPENAI:
            api_base = config.pop("api_base", DEFAULT_OPENAI_API_BASE)
            return OpenAIEmbedding(
                model=model,
                api_base=api_base,
                api_key=credentials,
                **config,
            )
        case EmbeddingProvider.JINA:
            return JinaEmbedding(
                model=model,
                api_key=credentials,
                **config,
            )
        case EmbeddingProvider.COHERE:
            return CohereEmbedding(
                model_name=model,
                cohere_api_key=credentials,
            )
        case EmbeddingProvider.OLLAMA:
            return OllamaEmbedding(
                model_name=model,
                **config,
            )
        case EmbeddingProvider.LOCAL:
            return LocalEmbedding(
                model=model,
                **config,
            )
        case _:
            raise ValueError(f"Got unknown embedding provider: {provider}")


def get_default_embedding_model(session: Session) -> BaseEmbedding:
    db_embedding_model = session.exec(
        select(DBEmbeddingModel).order_by(DBEmbeddingModel.is_default.desc()).limit(1)
    ).first()
    if not db_embedding_model:
        raise ValueError("No default embedding model found in DB")
    return get_embedding_model(
        db_embedding_model.provider,
        db_embedding_model.model,
        db_embedding_model.config,
        db_embedding_model.credentials,
    )


def get_reranker_model(
    provider: RerankerProvider,
    model: str,
    top_n: int,
    config: dict,
    credentials: str | list | dict | None,
) -> BaseNodePostprocessor:
    match provider:
        case RerankerProvider.JINA:
            return JinaRerank(
                model=model,
                top_n=top_n,
                api_key=credentials,
            )
        case RerankerProvider.COHERE:
            return CohereRerank(
                model=model,
                top_n=top_n,
                api_key=credentials,
            )
        case RerankerProvider.BAISHENG:
            return BaishengRerank(
                model=model,
                top_n=top_n,
                api_key=credentials,
                **config,
            )
        case RerankerProvider.LOCAL:
            return LocalRerank(
                model=model,
                top_n=top_n,
                **config,
            )
        case _:
            raise ValueError(f"Got unknown reranker provider: {provider}")


def get_default_reranker_model(session: Session) -> Optional[BaseNodePostprocessor]:
    db_reranker = session.exec(
        select(DBRerankerModel).order_by(DBRerankerModel.is_default.desc()).limit(1)
    ).first()
    if not db_reranker:
        return None
    return get_reranker_model(
        db_reranker.provider,
        db_reranker.model,
        db_reranker.top_n,
        db_reranker.config,
        db_reranker.credentials,
    )


def get_metadata_post_filter(
    filters: Optional[MetadataFilters] = None,
) -> BaseNodePostprocessor:
    return MetadataPostFilter(filters)
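
# Example (a sketch; `my_filters` is a hypothetical MetadataFilters value
# whose exact shape is defined in
# app.rag.node_postprocessor.metadata_post_filter, not verified here):
#
#     config = ChatEngineConfig(
#         vector_search=VectorSearchOption(metadata_post_filters=my_filters),
#     )
#     post_filter = config.get_metadata_filter()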