13 changes: 9 additions & 4 deletions brevia/completions.py
@@ -3,13 +3,14 @@
from langchain.chains.llm import LLMChain
from langchain_core.prompts.loading import load_prompt_from_config
from pydantic import BaseModel
from brevia.models import load_chatmodel
from brevia.models import load_chatmodel, get_model_config
from brevia.settings import get_settings


class CompletionParams(BaseModel):
""" Q&A basic conversation chain params """
prompt: dict | None = None
config: dict | None = None


def load_custom_prompt(prompt: dict | None):
@@ -28,10 +29,14 @@ def simple_completion_chain(
"""

settings = get_settings()
llm_conf = settings.qa_completion_llm.copy()
comp_llm = load_chatmodel(llm_conf)
verbose = settings.verbose_mode
# Create chain for follow-up question using chat history (if present)

llm_conf = get_model_config(
'qa_completion_llm',
user_config=completion_params.config,
default=settings.qa_completion_llm.copy()
)
comp_llm = load_chatmodel(llm_conf)
completion_llm = LLMChain(
llm=comp_llm,
prompt=load_custom_prompt(completion_params.prompt),
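The new config field lets a caller override the completion model per request. A minimal sketch of the override, assuming brevia settings are already configured (the model keys are illustrative; valid keys are whatever load_chatmodel accepts):

from brevia.completions import simple_completion_chain, CompletionParams

# Per-request override: 'qa_completion_llm' found in config takes
# precedence over settings.qa_completion_llm (see get_model_config below).
params = CompletionParams(
    config={'qa_completion_llm': {'model_name': 'gpt-4o-mini', 'temperature': 0}},
)
chain = simple_completion_chain(params)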
27 changes: 27 additions & 0 deletions brevia/models.py
@@ -1,5 +1,6 @@
"""Utilities to create langchain LLM and Chat Model instances."""
from abc import ABC, abstractmethod
from glom import glom
from typing import Any
from langchain.chat_models.base import init_chat_model
from langchain_core.embeddings import Embeddings
@@ -29,6 +30,32 @@ def get_token_ids(self, text: str) -> list[int]:
}"""


def get_model_config(
key: str,
user_config: dict | None = None,
db_metadata: dict | None = None,
default: object = None,
) -> object:
"""
Retrieve a model configuration value by searching in user config, db metadata,
and settings, in order. Uses glom for safe nested lookup.
"""
# Check user config and db metadata
for source in [user_config, db_metadata]:
if source is not None:
value = glom(source, key, default=None)
if value is not None:
return value

# Check settings only if it's a known key
settings = get_settings()
if hasattr(settings, key):
value = getattr(settings, key)
return value.copy() if hasattr(value, "copy") else value

return default


def load_llm(config: dict) -> BaseLLM:
"""Load langchain LLM, use Fake LLM in test mode"""
if test_models_in_use():
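get_model_config resolves a key through three layers in order — user config, db metadata, then application settings — before falling back to default. Since glom performs the lookup, dotted paths reach nested values. A short sketch under those assumptions (settings are never touched here because both lookups hit before the settings fallback):

from brevia.models import get_model_config

user_cfg = {'qa_completion_llm': {'model_name': 'from-user'}}
db_meta = {'qa_completion_llm': {'model_name': 'from-db'}}

# User config wins over db metadata and settings.
conf = get_model_config('qa_completion_llm', user_config=user_cfg, db_metadata=db_meta)
assert conf == {'model_name': 'from-user'}

# glom accepts dotted paths, so a single nested field can be pulled out.
name = get_model_config('qa_completion_llm.model_name', db_metadata=db_meta)
assert name == 'from-db'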
36 changes: 25 additions & 11 deletions brevia/query.py
@@ -15,7 +15,7 @@
from pydantic import BaseModel
from brevia.connection import connection_string
from brevia.collections import single_collection_by_name
from brevia.models import load_chatmodel, load_embeddings
from brevia.models import load_chatmodel, load_embeddings, get_model_config
from brevia.prompts import load_qa_prompt, load_condense_prompt
from brevia.settings import get_settings
from brevia.utilities.types import load_type
@@ -67,6 +67,7 @@ class ChatParams(BaseModel):
filter: dict[str, str | dict] | None = None
source_docs: bool = False
multiquery: bool = False
config: dict | None = None
search_type: str = "similarity"
score_threshold: float = 0.0

@@ -217,6 +218,8 @@ def conversation_rag_chain(
- search_type (str): Type of search algorithm to use (def is 'similarity').
- score_threshold (float): Threshold for filtering documents based on
relevance scores (default is 0.0).
- config (dict | None): Optional configuration dict that can contain
completion_llm and followup_llm configs to override defaults.
answer_callbacks (list[BaseCallbackHandler] | None): List of callback handlers
for the final LLM answer to enable streaming (default is None).

@@ -232,10 +235,11 @@
prompts = collection.cmetadata.get('prompts', {})
prompts = prompts if prompts else {}

# Main LLM configuration
qa_llm_conf = collection.cmetadata.get(
# Main LLM configuration using get_model_config
qa_llm_conf = get_model_config(
'qa_completion_llm',
dict(settings.qa_completion_llm).copy()
user_config=chat_params.config,
db_metadata=collection.cmetadata
)
qa_llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks
qa_llm_conf['streaming'] = chat_params.streaming
@@ -248,10 +252,11 @@
llm=chatllm
)

# Chain to rewrite question with history
fup_llm_conf = collection.cmetadata.get(
# Chain to rewrite question with history using get_model_config
fup_llm_conf = get_model_config(
'qa_followup_llm',
dict(settings.qa_followup_llm).copy()
user_config=chat_params.config,
db_metadata=collection.cmetadata
)
fup_llm = load_chatmodel(fup_llm_conf)
fup_chain = (
@@ -292,18 +297,27 @@ def conversation_chain(
collection or dataset.
"""

settings = get_settings()

# Chain to rewrite question with history
fup_llm_conf = dict(settings.qa_followup_llm).copy()
# Check if followup_llm config is provided in chat_params
fup_llm_conf = get_model_config(
'qa_followup_llm',
user_config=chat_params.config
)

fup_llm = load_chatmodel(fup_llm_conf)
fup_chain = (
load_condense_prompt()
| fup_llm
| StrOutputParser()
)

llm_conf = dict(settings.qa_completion_llm).copy()
# Main LLM configuration
# Check if completion_llm config is provided in chat_params
llm_conf = get_model_config(
'qa_completion_llm',
user_config=chat_params.config
)

llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks
llm_conf['streaming'] = chat_params.streaming

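On the query side the same override arrives through ChatParams.config and can target both chains at once; in RAG mode, collection metadata still sits between the user config and settings. A sketch, assuming conversation_chain's parameter is named chat_params as its body suggests and the remaining ChatParams fields keep their defaults (model names illustrative):

from brevia.query import conversation_chain, ChatParams

params = ChatParams(config={
    'qa_completion_llm': {'model_name': 'gpt-4o'},      # answer chain
    'qa_followup_llm': {'model_name': 'gpt-4o-mini'},   # question-rewrite chain
})
chain = conversation_chain(chat_params=params)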
1 change: 1 addition & 0 deletions brevia/routers/completion_router.py
@@ -14,6 +14,7 @@ class CompletionBody(CompletionParams):
""" /completion request body """
text: str
prompt: dict | None = None
config: dict | None = None
token_data: bool = False


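Over HTTP the new field rides along in the /completion body (path taken from the router docstring; the host and model keys below are assumptions):

import requests

body = {
    'text': 'Summarize the following paragraph ...',
    'config': {'qa_completion_llm': {'model_name': 'gpt-4o-mini'}},
}
resp = requests.post('http://localhost:8000/completion', json=body)
print(resp.json())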
3 changes: 2 additions & 1 deletion brevia/routers/qa_router.py
@@ -1,8 +1,8 @@
"""API endpoints for question answering and search"""
import asyncio
from glom import glom
from typing import Annotated
from typing_extensions import Self
from glom import glom
from pydantic import Field, model_validator
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains.base import Chain
@@ -39,6 +39,7 @@ class ChatBody(ChatParams):
collection: str | None = None
chat_history: list = []
chat_lang: str | None = None
config: dict | None = None
mode: str = Field(pattern='^(rag|conversation)$', default='rag')
token_data: bool = False

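ChatBody gains the same field. Assuming the question-answering endpoint is mounted at /chat and accepts a question field (neither is shown in this diff), a request could look like:

import requests

body = {
    'question': 'What does the contract say about termination?',
    'collection': 'docs',
    'config': {'qa_followup_llm': {'model_name': 'gpt-4o-mini'}},
}
resp = requests.post('http://localhost:8000/chat', json=body)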
41 changes: 41 additions & 0 deletions tests/test_completions.py
@@ -1,6 +1,7 @@
"""Query module tests"""
from langchain.chains.base import Chain
from brevia.completions import simple_completion_chain, CompletionParams
from brevia.models import get_model_config


fake_prompt = CompletionParams()
@@ -16,3 +17,43 @@ def test_simple_completion_chain():
result = simple_completion_chain(fake_prompt)
assert result is not None
assert isinstance(result, Chain)


def test_get_model_config():
"""Test get_model_config functionality"""
# Test data using qa_completion_llm which exists in settings
test_key = "qa_completion_llm"
test_user_config = {
"qa_completion_llm": {"model": "from_user"}
}
test_db_config = {
"qa_completion_llm": {"model": "from_db"}
}
test_default = {"model": "default"}

# Test user config priority
result = get_model_config(
test_key,
user_config=test_user_config,
db_metadata=test_db_config,
default=test_default
)
assert result == {"model": "from_user"}

# Test db config fallback
result = get_model_config(
test_key,
user_config=None,
db_metadata=test_db_config,
default=test_default
)
assert result == {"model": "from_db"}

# Test default fallback when key not in settings
result = get_model_config(
"non_existent_key",
user_config=None,
db_metadata=None,
default=test_default
)
assert result == test_default