diff --git a/brevia/completions.py b/brevia/completions.py
index a751470..a9527d5 100644
--- a/brevia/completions.py
+++ b/brevia/completions.py
@@ -3,13 +3,14 @@
 from langchain.chains.llm import LLMChain
 from langchain_core.prompts.loading import load_prompt_from_config
 from pydantic import BaseModel
-from brevia.models import load_chatmodel
+from brevia.models import load_chatmodel, get_model_config
 from brevia.settings import get_settings


 class CompletionParams(BaseModel):
     """ Q&A basic conversation chain params """
     prompt: dict | None = None
+    config: dict | None = None


 def load_custom_prompt(prompt: dict | None):
@@ -28,10 +29,14 @@ def simple_completion_chain(

     """
     settings = get_settings()
-    llm_conf = settings.qa_completion_llm.copy()
-    comp_llm = load_chatmodel(llm_conf)
     verbose = settings.verbose_mode
-    # Create chain for follow-up question using chat history (if present)
+
+    llm_conf = get_model_config(
+        'qa_completion_llm',
+        user_config=completion_params.config,
+        default=settings.qa_completion_llm.copy()
+    )
+    comp_llm = load_chatmodel(llm_conf)
     completion_llm = LLMChain(
         llm=comp_llm,
         prompt=load_custom_prompt(completion_params.prompt),
diff --git a/brevia/models.py b/brevia/models.py
index df067c3..dd92b97 100644
--- a/brevia/models.py
+++ b/brevia/models.py
@@ -1,5 +1,6 @@
 """Utilities to create langchain LLM and Chat Model instances."""
 from abc import ABC, abstractmethod
+from glom import glom
 from typing import Any
 from langchain.chat_models.base import init_chat_model
 from langchain_core.embeddings import Embeddings
@@ -29,6 +30,32 @@ def get_token_ids(self, text: str) -> list[int]:
 }"""


+def get_model_config(
+    key: str,
+    user_config: dict | None = None,
+    db_metadata: dict | None = None,
+    default: object = None,
+) -> object:
+    """
+    Retrieve a model configuration value by searching in user config, db metadata,
+    and settings, in order. Uses glom for safe nested lookup.
+    """
+    # Check user config and db metadata
+    for source in [user_config, db_metadata]:
+        if source is not None:
+            value = glom(source, key, default=None)
+            if value is not None:
+                return value
+
+    # Check settings only if it's a known key
+    settings = get_settings()
+    if hasattr(settings, key):
+        value = getattr(settings, key)
+        return value.copy() if hasattr(value, "copy") else value
+
+    return default
+
+
 def load_llm(config: dict) -> BaseLLM:
     """Load langchain LLM, use Fake LLM in test mode"""
     if test_models_in_use():
diff --git a/brevia/query.py b/brevia/query.py
index 4c508d7..8f1f917 100644
--- a/brevia/query.py
+++ b/brevia/query.py
@@ -15,7 +15,7 @@
 from pydantic import BaseModel
 from brevia.connection import connection_string
 from brevia.collections import single_collection_by_name
-from brevia.models import load_chatmodel, load_embeddings
+from brevia.models import load_chatmodel, load_embeddings, get_model_config
 from brevia.prompts import load_qa_prompt, load_condense_prompt
 from brevia.settings import get_settings
 from brevia.utilities.types import load_type
@@ -67,6 +67,7 @@ class ChatParams(BaseModel):
     filter: dict[str, str | dict] | None = None
     source_docs: bool = False
     multiquery: bool = False
+    config: dict | None = None
     search_type: str = "similarity"
     score_threshold: float = 0.0

@@ -217,6 +218,8 @@ def conversation_rag_chain(
             - search_type (str): Type of search algorithm to use (def is 'similarity').
             - score_threshold (float): Threshold for filtering documents based on
                 relevance scores (default is 0.0).
+            - config (dict | None): Optional configuration dict that can contain
+                completion_llm and followup_llm configs to override defaults.

         answer_callbacks (list[BaseCallbackHandler] | None): List of callback handlers
             for the final LLM answer to enable streaming (default is None).
@@ -232,10 +235,11 @@ def conversation_rag_chain(
     prompts = collection.cmetadata.get('prompts', {})
     prompts = prompts if prompts else {}

-    # Main LLM configuration
-    qa_llm_conf = collection.cmetadata.get(
+    # Main LLM configuration using get_model_config
+    qa_llm_conf = get_model_config(
         'qa_completion_llm',
-        dict(settings.qa_completion_llm).copy()
+        user_config=chat_params.config,
+        db_metadata=collection.cmetadata
     )
     qa_llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks
     qa_llm_conf['streaming'] = chat_params.streaming
@@ -248,10 +252,11 @@ def conversation_rag_chain(
         llm=chatllm
     )

-    # Chain to rewrite question with history
-    fup_llm_conf = collection.cmetadata.get(
+    # Chain to rewrite question with history using get_model_config
+    fup_llm_conf = get_model_config(
         'qa_followup_llm',
-        dict(settings.qa_followup_llm).copy()
+        user_config=chat_params.config,
+        db_metadata=collection.cmetadata
     )
     fup_llm = load_chatmodel(fup_llm_conf)
     fup_chain = (
@@ -292,10 +297,13 @@ def conversation_chain(
         collection or dataset.
     """
-    settings = get_settings()

-    # Chain to rewrite question with history
-    fup_llm_conf = dict(settings.qa_followup_llm).copy()
+    # Check if followup_llm config is provided in chat_params
+    fup_llm_conf = get_model_config(
+        'qa_followup_llm',
+        user_config=chat_params.config
+    )
+
     fup_llm = load_chatmodel(fup_llm_conf)
     fup_chain = (
         load_condense_prompt()
@@ -303,7 +311,13 @@
         | StrOutputParser()
     )

-    llm_conf = dict(settings.qa_completion_llm).copy()
+    # Main LLM configuration
+    # Check if completion_llm config is provided in chat_params
+    llm_conf = get_model_config(
+        'qa_completion_llm',
+        user_config=chat_params.config
+    )
+
     llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks
     llm_conf['streaming'] = chat_params.streaming

diff --git a/brevia/routers/completion_router.py b/brevia/routers/completion_router.py
index a0085a5..b3e120c 100644
--- a/brevia/routers/completion_router.py
+++ b/brevia/routers/completion_router.py
@@ -14,6 +14,7 @@ class CompletionBody(CompletionParams):
     """ /completion request body """
     text: str
     prompt: dict | None = None
+    config: dict | None = None
     token_data: bool = False


diff --git a/brevia/routers/qa_router.py b/brevia/routers/qa_router.py
index 0b35864..2a1b142 100644
--- a/brevia/routers/qa_router.py
+++ b/brevia/routers/qa_router.py
@@ -1,8 +1,8 @@
 """API endpoints for question answering and search"""
 import asyncio
-from glom import glom
 from typing import Annotated
 from typing_extensions import Self
+from glom import glom
 from pydantic import Field, model_validator
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.chains.base import Chain
@@ -39,6 +39,7 @@ class ChatBody(ChatParams):
     collection: str | None = None
     chat_history: list = []
     chat_lang: str | None = None
+    config: dict | None = None
     mode: str = Field(pattern='^(rag|conversation)$', default='rag')
     token_data: bool = False

diff --git a/tests/test_completions.py b/tests/test_completions.py
index 13f515c..cb11caf 100644
--- a/tests/test_completions.py
+++ b/tests/test_completions.py
@@ -1,6 +1,7 @@
 """Query module tests"""
 from langchain.chains.base import Chain
 from brevia.completions import simple_completion_chain, CompletionParams
+from brevia.models import get_model_config


 fake_prompt = CompletionParams()
@@ -16,3 +17,43 @@ def test_simple_completion_chain():
     result = simple_completion_chain(fake_prompt)
     assert result is not None
     assert isinstance(result, Chain)
+
+
+def test_get_model_config():
+    """Test get_model_config functionality"""
+    # Test data using qa_completion_llm which exists in settings
+    test_key = "qa_completion_llm"
+    test_user_config = {
+        "qa_completion_llm": {"model": "from_user"}
+    }
+    test_db_config = {
+        "qa_completion_llm": {"model": "from_db"}
+    }
+    test_default = {"model": "default"}
+
+    # Test user config priority
+    result = get_model_config(
+        test_key,
+        user_config=test_user_config,
+        db_metadata=test_db_config,
+        default=test_default
+    )
+    assert result == {"model": "from_user"}
+
+    # Test db config fallback
+    result = get_model_config(
+        test_key,
+        user_config=None,
+        db_metadata=test_db_config,
+        default=test_default
+    )
+    assert result == {"model": "from_db"}
+
+    # Test default fallback when key not in settings
+    result = get_model_config(
+        "non_existent_key",
+        user_config=None,
+        db_metadata=None,
+        default=test_default
+    )
+    assert result == test_default
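
For reference, a minimal sketch of how a client could exercise the new per-request `config` override through the /completion endpoint (CompletionBody above). Only the payload keys (text, config, qa_completion_llm, model) come from the changes in this patch; the base URL, model name, and use of the requests library are illustrative assumptions.

import requests

# Per-request model override: get_model_config() resolves 'qa_completion_llm'
# from this user-supplied config first, before collection metadata and settings.
payload = {
    "text": "Summarize the following paragraph ...",
    "config": {
        "qa_completion_llm": {"model": "gpt-4o-mini"},  # illustrative model name
    },
}

# Assumed local Brevia instance; adjust host and port as needed.
response = requests.post("http://localhost:8000/completion", json=payload)
print(response.json())

Omitting the config key (or the qa_completion_llm entry inside it) leaves the existing behavior untouched, since get_model_config() falls back to settings.qa_completion_llm.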