Merged · Changes from 2 commits
12 changes: 9 additions & 3 deletions brevia/completions.py
@@ -10,6 +10,7 @@
 class CompletionParams(BaseModel):
     """ Q&A basic conversation chain params """
     prompt: dict | None = None
+    config: dict | None = None


 def load_custom_prompt(prompt: dict | None):
@@ -28,10 +29,15 @@ def simple_completion_chain(
"""

settings = get_settings()
llm_conf = settings.qa_completion_llm.copy()
comp_llm = load_chatmodel(llm_conf)
verbose = settings.verbose_mode
# Create chain for follow-up question using chat history (if present)

# Check if completion_llm config is provided in completion_params
if (completion_params.config
and completion_params.config.get('completion_llm')):
llm_conf = completion_params.config['completion_llm'].copy()
else:
llm_conf = settings.qa_completion_llm.copy()
comp_llm = load_chatmodel(llm_conf)
completion_llm = LLMChain(
llm=comp_llm,
prompt=load_custom_prompt(completion_params.prompt),
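For context, here is how the new override could be exercised from Python. This is a minimal sketch, assuming simple_completion_chain accepts a CompletionParams instance and that the keys inside completion_llm (model_name, temperature) match whatever load_chatmodel expects in your setup:

from brevia.completions import CompletionParams, simple_completion_chain

# Per-call model override: with no 'config' the chain falls back to
# settings.qa_completion_llm, exactly as before this change.
params = CompletionParams(
    config={
        'completion_llm': {
            'model_name': 'gpt-4o-mini',  # hypothetical keys, loader-dependent
            'temperature': 0,
        },
    },
)
chain = simple_completion_chain(completion_params=params)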
46 changes: 36 additions & 10 deletions brevia/query.py
@@ -67,6 +67,7 @@ class ChatParams(BaseModel):
     filter: dict[str, str | dict] | None = None
     source_docs: bool = False
     multiquery: bool = False
+    config: dict | None = None
     search_type: str = "similarity"
     score_threshold: float = 0.0

@@ -217,6 +218,8 @@ def conversation_rag_chain(
         - search_type (str): Type of search algorithm to use (def is 'similarity').
         - score_threshold (float): Threshold for filtering documents based on
             relevance scores (default is 0.0).
+        - config (dict | None): Optional configuration dict that can contain
+            completion_llm and followup_llm configs to override defaults.
     answer_callbacks (list[BaseCallbackHandler] | None): List of callback handlers
         for the final LLM answer to enable streaming (default is None).

@@ -233,10 +236,15 @@
     prompts = prompts if prompts else {}

     # Main LLM configuration
-    qa_llm_conf = collection.cmetadata.get(
-        'qa_completion_llm',
-        dict(settings.qa_completion_llm).copy()
-    )
+    # Check if completion_llm config is provided in chat_params
+    if (chat_params.config
Copilot AI commented on Jun 10, 2025:

The logic for retrieving config keys is repeated in multiple chain functions; consider refactoring this into a helper function to improve maintainability.
Contributor commented:

@nikazzio as Copilot suggests, we could have a helper function in brevia.models or brevia.settings like this:
from glom import glom
from brevia.settings import get_settings
# ...existing code...

def get_model_config(
    key: str,
    user_config: dict | None = None,
    db_metadata: dict | None = None,
    default: object = None,
) -> object:
    """
    Retrieve a model configuration value by searching in user config, db metadata, and settings, in order.
    Uses glom for safe nested lookup.
    """
    for source in [user_config, db_metadata]:
        if source is not None:
            # default=None keeps the lookup safe: without it glom raises
            # PathAccessError on missing keys
            value = glom(source, key, default=None)
            if value is not None:
                return value
    value = glom(get_settings(), key, default=None)
    if value is not None:
        # return a copy from settings to avoid mutations
        return value.copy() if hasattr(value, "copy") else value
    return default

We could then add a test case for this function and avoid testing the actual configuration cases separately.
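A minimal pytest sketch of that idea, assuming the helper lands in brevia.settings as proposed; the nested key shapes and values are placeholders for illustration:

from brevia.settings import get_model_config  # assumed location

def test_get_model_config_precedence():
    user = {'completion_llm': {'model_name': 'user-model'}}
    meta = {'completion_llm': {'model_name': 'db-model'}}
    # user config wins over db metadata
    assert get_model_config('completion_llm', user, meta)['model_name'] == 'user-model'
    # db metadata is used when no user config is given
    assert get_model_config('completion_llm', None, meta)['model_name'] == 'db-model'
    # explicit default is returned when no source provides the key
    assert get_model_config('some_missing_key', None, None, default={}) == {}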

+            and chat_params.config.get('completion_llm')):
+        qa_llm_conf = chat_params.config['completion_llm'].copy()
+    else:
+        qa_llm_conf = collection.cmetadata.get(
+            'qa_completion_llm',
+            dict(settings.qa_completion_llm).copy()
+        )
     qa_llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks
     qa_llm_conf['streaming'] = chat_params.streaming
     chatllm = load_chatmodel(qa_llm_conf)
@@ -249,10 +257,15 @@
     )

     # Chain to rewrite question with history
-    fup_llm_conf = collection.cmetadata.get(
-        'qa_followup_llm',
-        dict(settings.qa_followup_llm).copy()
-    )
+    # Check if followup_llm config is provided in chat_params
+    if (chat_params.config
+            and chat_params.config.get('followup_llm')):
+        fup_llm_conf = chat_params.config['followup_llm'].copy()
+    else:
+        fup_llm_conf = collection.cmetadata.get(
+            'qa_followup_llm',
+            dict(settings.qa_followup_llm).copy()
+        )
     fup_llm = load_chatmodel(fup_llm_conf)
     fup_chain = (
         load_condense_prompt(prompts.get('condense'))
@@ -295,15 +308,28 @@ def conversation_chain(
     settings = get_settings()

     # Chain to rewrite question with history
-    fup_llm_conf = dict(settings.qa_followup_llm).copy()
+    # Check if followup_llm config is provided in chat_params
+    if (chat_params.config
+            and chat_params.config.get('followup_llm')):
+        fup_llm_conf = chat_params.config['followup_llm'].copy()
+    else:
+        fup_llm_conf = dict(settings.qa_followup_llm).copy()
+
     fup_llm = load_chatmodel(fup_llm_conf)
     fup_chain = (
         load_condense_prompt()
         | fup_llm
         | StrOutputParser()
     )

-    llm_conf = dict(settings.qa_completion_llm).copy()
+    # Main LLM configuration
+    # Check if completion_llm config is provided in chat_params
+    if (chat_params.config
+            and chat_params.config.get('completion_llm')):
+        llm_conf = chat_params.config['completion_llm'].copy()
+    else:
+        llm_conf = dict(settings.qa_completion_llm).copy()
+
     llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks
     llm_conf['streaming'] = chat_params.streaming
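To illustrate the resulting precedence (request config first, then collection metadata where applicable, then settings), a hedged sketch of a caller; ChatParams and conversation_chain come from the diff above, while the inner config keys and the keyword-only call style are assumptions:

from brevia.query import ChatParams, conversation_chain

params = ChatParams(
    config={
        # hypothetical per-request overrides; key names depend on load_chatmodel
        'completion_llm': {'model_name': 'gpt-4o-mini', 'temperature': 0},
        'followup_llm': {'model_name': 'gpt-4o-mini', 'temperature': 0},
    },
)
chain = conversation_chain(chat_params=params)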
1 change: 1 addition & 0 deletions brevia/routers/completion_router.py
@@ -14,6 +14,7 @@ class CompletionBody(CompletionParams):
""" /completion request body """
text: str
prompt: dict | None = None
config: dict | None = None
token_data: bool = False


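At the API level the same override now travels in the request body; a hedged sketch with requests, where the server URL is a placeholder and the inner completion_llm keys are assumptions, while text and config come from CompletionBody above:

import requests

payload = {
    'text': 'Summarize the following notes ...',
    'config': {
        'completion_llm': {'model_name': 'gpt-4o-mini', 'temperature': 0},
    },
}
# placeholder URL for a local Brevia instance
response = requests.post('http://localhost:8000/completion', json=payload)
print(response.json())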
3 changes: 2 additions & 1 deletion brevia/routers/qa_router.py
@@ -1,8 +1,8 @@
"""API endpoints for question answering and search"""
import asyncio
from glom import glom
from typing import Annotated
from typing_extensions import Self
from glom import glom
from pydantic import Field, model_validator
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains.base import Chain
@@ -39,6 +39,7 @@ class ChatBody(ChatParams):
     collection: str | None = None
     chat_history: list = []
     chat_lang: str | None = None
+    config: dict | None = None
     mode: str = Field(pattern='^(rag|conversation)$', default='rag')
     token_data: bool = False

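And similarly for /chat; a hedged sketch in which the question field name, server URL, and inner config keys are assumptions, while collection, mode, and config come from ChatBody above:

import requests

payload = {
    'question': 'What changed in the last release?',  # assumed field name
    'collection': 'docs',
    'mode': 'rag',
    'config': {
        'completion_llm': {'model_name': 'gpt-4o-mini', 'temperature': 0},
        'followup_llm': {'model_name': 'gpt-4o-mini', 'temperature': 0},
    },
}
response = requests.post('http://localhost:8000/chat', json=payload)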