From 7b7121e81483628cba6847e7e24ced6d3edc36e3 Mon Sep 17 00:00:00 2001 From: nikazzio Date: Fri, 6 Jun 2025 14:11:50 +0200 Subject: [PATCH 1/5] feat: add config parameter to completion and chat parameters for customizable LLM configurations --- brevia/completions.py | 12 +++++++++--- brevia/query.py | 29 +++++++++++++++++++++-------- brevia/routers/completion_router.py | 1 + brevia/routers/qa_router.py | 3 ++- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/brevia/completions.py b/brevia/completions.py index a7514708..52f51b33 100644 --- a/brevia/completions.py +++ b/brevia/completions.py @@ -10,6 +10,7 @@ class CompletionParams(BaseModel): """ Q&A basic conversation chain params """ prompt: dict | None = None + config: dict | None = None def load_custom_prompt(prompt: dict | None): @@ -28,10 +29,15 @@ def simple_completion_chain( """ settings = get_settings() - llm_conf = settings.qa_completion_llm.copy() - comp_llm = load_chatmodel(llm_conf) verbose = settings.verbose_mode - # Create chain for follow-up question using chat history (if present) + + # Check if completion_llm config is provided in completion_params + if (completion_params.config + and completion_params.config.get('completion_llm')): + llm_conf = completion_params.config['completion_llm'].copy() + else: + llm_conf = settings.qa_completion_llm.copy() + comp_llm = load_chatmodel(llm_conf) completion_llm = LLMChain( llm=comp_llm, prompt=load_custom_prompt(completion_params.prompt), diff --git a/brevia/query.py b/brevia/query.py index 4c508d7a..aca1d7a1 100644 --- a/brevia/query.py +++ b/brevia/query.py @@ -67,6 +67,7 @@ class ChatParams(BaseModel): filter: dict[str, str | dict] | None = None source_docs: bool = False multiquery: bool = False + config: dict | None = None search_type: str = "similarity" score_threshold: float = 0.0 @@ -217,6 +218,8 @@ def conversation_rag_chain( - search_type (str): Type of search algorithm to use (def is 'similarity'). - score_threshold (float): Threshold for filtering documents based on relevance scores (default is 0.0). + - config (dict | None): Optional configuration dict that can contain + completion_llm and followup_llm configs to override defaults. answer_callbacks (list[BaseCallbackHandler] | None): List of callback handlers for the final LLM answer to enable streaming (default is None). 
@@ -233,10 +236,15 @@ def conversation_rag_chain( prompts = prompts if prompts else {} # Main LLM configuration - qa_llm_conf = collection.cmetadata.get( - 'qa_completion_llm', - dict(settings.qa_completion_llm).copy() - ) + # Check if completion_llm config is provided in chat_params + if (chat_params.config + and chat_params.config.get('completion_llm')): + qa_llm_conf = chat_params.config['completion_llm'].copy() + else: + qa_llm_conf = collection.cmetadata.get( + 'qa_completion_llm', + dict(settings.qa_completion_llm).copy() + ) qa_llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks qa_llm_conf['streaming'] = chat_params.streaming chatllm = load_chatmodel(qa_llm_conf) @@ -249,10 +257,15 @@ def conversation_rag_chain( ) # Chain to rewrite question with history - fup_llm_conf = collection.cmetadata.get( - 'qa_followup_llm', - dict(settings.qa_followup_llm).copy() - ) + # Check if followup_llm config is provided in chat_params + if (chat_params.config + and chat_params.config.get('followup_llm')): + fup_llm_conf = chat_params.config['followup_llm'].copy() + else: + fup_llm_conf = collection.cmetadata.get( + 'qa_followup_llm', + dict(settings.qa_followup_llm).copy() + ) fup_llm = load_chatmodel(fup_llm_conf) fup_chain = ( load_condense_prompt(prompts.get('condense')) diff --git a/brevia/routers/completion_router.py b/brevia/routers/completion_router.py index a0085a56..b3e120c2 100644 --- a/brevia/routers/completion_router.py +++ b/brevia/routers/completion_router.py @@ -14,6 +14,7 @@ class CompletionBody(CompletionParams): """ /completion request body """ text: str prompt: dict | None = None + config: dict | None = None token_data: bool = False diff --git a/brevia/routers/qa_router.py b/brevia/routers/qa_router.py index 0b358648..2a1b142b 100644 --- a/brevia/routers/qa_router.py +++ b/brevia/routers/qa_router.py @@ -1,8 +1,8 @@ """API endpoints for question answering and search""" import asyncio -from glom import glom from typing import Annotated from typing_extensions import Self +from glom import glom from pydantic import Field, model_validator from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.chains.base import Chain @@ -39,6 +39,7 @@ class ChatBody(ChatParams): collection: str | None = None chat_history: list = [] chat_lang: str | None = None + config: dict | None = None mode: str = Field(pattern='^(rag|conversation)$', default='rag') token_data: bool = False From f808b43001f716f59cd22d59f5240d9c6245f5cb Mon Sep 17 00:00:00 2001 From: nikazzio Date: Fri, 6 Jun 2025 16:26:01 +0200 Subject: [PATCH 2/5] feat: enhance conversation_chain to support customizable LLM configurations --- brevia/query.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/brevia/query.py b/brevia/query.py index aca1d7a1..f2d8e171 100644 --- a/brevia/query.py +++ b/brevia/query.py @@ -308,7 +308,13 @@ def conversation_chain( settings = get_settings() # Chain to rewrite question with history - fup_llm_conf = dict(settings.qa_followup_llm).copy() + # Check if followup_llm config is provided in chat_params + if (chat_params.config + and chat_params.config.get('followup_llm')): + fup_llm_conf = chat_params.config['followup_llm'].copy() + else: + fup_llm_conf = dict(settings.qa_followup_llm).copy() + fup_llm = load_chatmodel(fup_llm_conf) fup_chain = ( load_condense_prompt() @@ -316,7 +322,14 @@ def conversation_chain( | StrOutputParser() ) - llm_conf = dict(settings.qa_completion_llm).copy() + # Main LLM configuration + # Check if 
completion_llm config is provided in chat_params + if (chat_params.config + and chat_params.config.get('completion_llm')): + llm_conf = chat_params.config['completion_llm'].copy() + else: + llm_conf = dict(settings.qa_completion_llm).copy() + llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks llm_conf['streaming'] = chat_params.streaming From 613a78d9c7c6e19586ac2f83ee88503c444f2dcf Mon Sep 17 00:00:00 2001 From: nikazzio Date: Fri, 13 Jun 2025 14:05:00 +0200 Subject: [PATCH 3/5] feat: implement get_model_config function for flexible model configuration retrieval --- brevia/completions.py | 13 ++++++------- brevia/models.py | 27 ++++++++++++++++++++++++++ brevia/query.py | 34 +++++++++++++------------------- tests/test_completions.py | 41 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 28 deletions(-) diff --git a/brevia/completions.py b/brevia/completions.py index 52f51b33..a9527d57 100644 --- a/brevia/completions.py +++ b/brevia/completions.py @@ -3,7 +3,7 @@ from langchain.chains.llm import LLMChain from langchain_core.prompts.loading import load_prompt_from_config from pydantic import BaseModel -from brevia.models import load_chatmodel +from brevia.models import load_chatmodel, get_model_config from brevia.settings import get_settings @@ -31,12 +31,11 @@ def simple_completion_chain( settings = get_settings() verbose = settings.verbose_mode - # Check if completion_llm config is provided in completion_params - if (completion_params.config - and completion_params.config.get('completion_llm')): - llm_conf = completion_params.config['completion_llm'].copy() - else: - llm_conf = settings.qa_completion_llm.copy() + llm_conf = get_model_config( + 'qa_completion_llm', + user_config=completion_params.config, + default=settings.qa_completion_llm.copy() + ) comp_llm = load_chatmodel(llm_conf) completion_llm = LLMChain( llm=comp_llm, diff --git a/brevia/models.py b/brevia/models.py index df067c3a..dd92b97b 100644 --- a/brevia/models.py +++ b/brevia/models.py @@ -1,5 +1,6 @@ """Utilities to create langchain LLM and Chat Model instances.""" from abc import ABC, abstractmethod +from glom import glom from typing import Any from langchain.chat_models.base import init_chat_model from langchain_core.embeddings import Embeddings @@ -29,6 +30,32 @@ def get_token_ids(self, text: str) -> list[int]: }""" +def get_model_config( + key: str, + user_config: dict | None = None, + db_metadata: dict | None = None, + default: object = None, +) -> object: + """ + Retrieve a model configuration value by searching in user config, db metadata, + and settings, in order. Uses glom for safe nested lookup. 
+ """ + # Check user config and db metadata + for source in [user_config, db_metadata]: + if source is not None: + value = glom(source, key, default=None) + if value is not None: + return value + + # Check settings only if it's a known key + settings = get_settings() + if hasattr(settings, key): + value = getattr(settings, key) + return value.copy() if hasattr(value, "copy") else value + + return default + + def load_llm(config: dict) -> BaseLLM: """Load langchain LLM, use Fake LLM in test mode""" if test_models_in_use(): diff --git a/brevia/query.py b/brevia/query.py index f2d8e171..0a2f0272 100644 --- a/brevia/query.py +++ b/brevia/query.py @@ -15,7 +15,7 @@ from pydantic import BaseModel from brevia.connection import connection_string from brevia.collections import single_collection_by_name -from brevia.models import load_chatmodel, load_embeddings +from brevia.models import load_chatmodel, load_embeddings, get_model_config from brevia.prompts import load_qa_prompt, load_condense_prompt from brevia.settings import get_settings from brevia.utilities.types import load_type @@ -235,16 +235,12 @@ def conversation_rag_chain( prompts = collection.cmetadata.get('prompts', {}) prompts = prompts if prompts else {} - # Main LLM configuration - # Check if completion_llm config is provided in chat_params - if (chat_params.config - and chat_params.config.get('completion_llm')): - qa_llm_conf = chat_params.config['completion_llm'].copy() - else: - qa_llm_conf = collection.cmetadata.get( - 'qa_completion_llm', - dict(settings.qa_completion_llm).copy() - ) + # Main LLM configuration using get_model_config + qa_llm_conf = get_model_config( + 'qa_completion_llm', + user_config=chat_params.config, + db_metadata=collection.cmetadata + ) qa_llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks qa_llm_conf['streaming'] = chat_params.streaming chatllm = load_chatmodel(qa_llm_conf) @@ -256,16 +252,12 @@ def conversation_rag_chain( llm=chatllm ) - # Chain to rewrite question with history - # Check if followup_llm config is provided in chat_params - if (chat_params.config - and chat_params.config.get('followup_llm')): - fup_llm_conf = chat_params.config['followup_llm'].copy() - else: - fup_llm_conf = collection.cmetadata.get( - 'qa_followup_llm', - dict(settings.qa_followup_llm).copy() - ) + # Chain to rewrite question with history using get_model_config + fup_llm_conf = get_model_config( + 'qa_followup_llm', + user_config=chat_params.config, + db_metadata=collection.cmetadata + ) fup_llm = load_chatmodel(fup_llm_conf) fup_chain = ( load_condense_prompt(prompts.get('condense')) diff --git a/tests/test_completions.py b/tests/test_completions.py index 13f515c4..cb11caf5 100644 --- a/tests/test_completions.py +++ b/tests/test_completions.py @@ -1,6 +1,7 @@ """Query module tests""" from langchain.chains.base import Chain from brevia.completions import simple_completion_chain, CompletionParams +from brevia.models import get_model_config fake_prompt = CompletionParams() @@ -16,3 +17,43 @@ def test_simple_completion_chain(): result = simple_completion_chain(fake_prompt) assert result is not None assert isinstance(result, Chain) + + +def test_get_model_config(): + """Test get_model_config functionality""" + # Test data using qa_completion_llm which exists in settings + test_key = "qa_completion_llm" + test_user_config = { + "qa_completion_llm": {"model": "from_user"} + } + test_db_config = { + "qa_completion_llm": {"model": "from_db"} + } + test_default = {"model": "default"} + + # Test user 
config priority + result = get_model_config( + test_key, + user_config=test_user_config, + db_metadata=test_db_config, + default=test_default + ) + assert result == {"model": "from_user"} + + # Test db config fallback + result = get_model_config( + test_key, + user_config=None, + db_metadata=test_db_config, + default=test_default + ) + assert result == {"model": "from_db"} + + # Test default fallback when key not in settings + result = get_model_config( + "non_existent_key", + user_config=None, + db_metadata=None, + default=test_default + ) + assert result == test_default From fec17b7b34e85fd67ab78b093c3339af11d6b643 Mon Sep 17 00:00:00 2001 From: nikazzio Date: Fri, 13 Jun 2025 15:56:25 +0200 Subject: [PATCH 4/5] feat: refactor conversation_chain to utilize get_model_config for LLM configurations --- brevia/query.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/brevia/query.py b/brevia/query.py index 0a2f0272..8d7254b5 100644 --- a/brevia/query.py +++ b/brevia/query.py @@ -301,11 +301,10 @@ def conversation_chain( # Chain to rewrite question with history # Check if followup_llm config is provided in chat_params - if (chat_params.config - and chat_params.config.get('followup_llm')): - fup_llm_conf = chat_params.config['followup_llm'].copy() - else: - fup_llm_conf = dict(settings.qa_followup_llm).copy() + fup_llm_conf = get_model_config( + 'qa_followup_llm', + user_config=chat_params.config + ) fup_llm = load_chatmodel(fup_llm_conf) fup_chain = ( @@ -316,11 +315,10 @@ def conversation_chain( # Main LLM configuration # Check if completion_llm config is provided in chat_params - if (chat_params.config - and chat_params.config.get('completion_llm')): - llm_conf = chat_params.config['completion_llm'].copy() - else: - llm_conf = dict(settings.qa_completion_llm).copy() + llm_conf = get_model_config( + 'qa_completion_llm', + user_config=chat_params.config + ) llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks llm_conf['streaming'] = chat_params.streaming From bf5c43dc40d3e44e177300e6a6157391bcd279b7 Mon Sep 17 00:00:00 2001 From: nikazzio Date: Fri, 13 Jun 2025 15:58:23 +0200 Subject: [PATCH 5/5] feat: remove unnecessary settings retrieval in conversation_chain function --- brevia/query.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/brevia/query.py b/brevia/query.py index 8d7254b5..8f1f9172 100644 --- a/brevia/query.py +++ b/brevia/query.py @@ -297,8 +297,6 @@ def conversation_chain( collection or dataset. """ - settings = get_settings() - # Chain to rewrite question with history # Check if followup_llm config is provided in chat_params fup_llm_conf = get_model_config(
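
Illustrative client-side usage of the new per-request `config` override introduced by this series (not part of the patches themselves). The `/completion` path comes from the CompletionBody docstring; the base URL and the keys inside "completion_llm" (model_provider, model, temperature) are assumptions that depend on the chat model loader behind load_chatmodel(). The same `config` shape applies to ChatBody, where a "followup_llm" entry may also be supplied; when no override is given, get_model_config() falls back to collection metadata (for chat) and then to the qa_completion_llm / qa_followup_llm settings, as exercised in the new test.

# Minimal sketch, assuming a locally running Brevia API and an OpenAI-style
# chat model; the keys inside "completion_llm" are illustrative only.
import requests

payload = {
    "text": "Summarize the indexed meeting notes in three bullet points.",
    "config": {
        "completion_llm": {
            "model_provider": "openai",   # assumed key
            "model": "gpt-4o-mini",       # assumed model name
            "temperature": 0,
        }
    },
}

# POST the body defined by CompletionBody; prompt and token_data keep their defaults.
response = requests.post("http://localhost:8000/completion", json=payload, timeout=60)
print(response.json())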