diff --git a/brevia/callback.py b/brevia/callback.py index eb2b7d3..5245a71 100644 --- a/brevia/callback.py +++ b/brevia/callback.py @@ -4,6 +4,7 @@ import asyncio import logging import json +from glom import glom from langchain_community.callbacks import OpenAICallbackHandler, get_openai_callback from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.documents import Document @@ -81,7 +82,11 @@ async def on_chain_end( ) -> None: """Run when chain ends running.""" if parent_run_id is None: - self.answer = outputs.get('answer') + self.answer = glom( + outputs, + 'answer', + default=glom(outputs, 'content', default='') + ) self.chain_ended.set() async def wait_conversation_done(self): diff --git a/brevia/chat_history.py b/brevia/chat_history.py index a5fc05e..c7def6d 100644 --- a/brevia/chat_history.py +++ b/brevia/chat_history.py @@ -115,12 +115,14 @@ def add_history( return None with Session(db_connection()) as session: - collection_store = CollectionStore.get_by_name(session, collection) - if not collection_store: - raise ValueError("Collection not found") + collection_uuid = None # Default empty UUID + if collection: + collection_store = CollectionStore.get_by_name(session, collection) + collection_uuid = collection_store.uuid if collection_store else None + chat_history_store = ChatHistoryStore( session_id=session_id, - collection_id=collection_store.uuid, + collection_id=collection_uuid, question=question, answer=answer, cmetadata=metadata, diff --git a/brevia/models.py b/brevia/models.py index 0deae0c..df067c3 100644 --- a/brevia/models.py +++ b/brevia/models.py @@ -22,8 +22,11 @@ def get_token_ids(self, text: str) -> list[int]: return [10] * 10 -LOREM_IPSUM = """Lorem ipsum dolor sit amet, consectetur adipisici elit, -sed eiusmod tempor incidunt ut labore et dolore magna aliqua.""" +LOREM_IPSUM = """{ + "question": "What is lorem ipsum?", + "answer": "Lorem ipsum dolor sit amet, consectetur adipisici elit, + sed eiusmod tempor incidunt ut labore et dolore magna aliqua." +}""" def load_llm(config: dict) -> BaseLLM: diff --git a/brevia/postman/Brevia API.postman_collection.json b/brevia/postman/Brevia API.postman_collection.json index 7ae5a76..2252e86 100644 --- a/brevia/postman/Brevia API.postman_collection.json +++ b/brevia/postman/Brevia API.postman_collection.json @@ -37,7 +37,49 @@ ], "body": { "mode": "raw", - "raw": "{\n \"question\" : \"{{query}}\",\n \"collection\" : \"{{collection}}\"\n}" + "raw": "{\n \"question\" : \"{{query}}\",\n \"mode\": \"rag\",\n \"collection\" : \"{{collection}}\"\n}" + }, + "url": { + "raw": "{{baseUrl}}/chat", + "host": [ + "{{baseUrl}}" + ], + "path": [ + "chat" + ] + } + }, + "response": [] + }, + { + "name": "chat - conversation, no RAG", + "request": { + "auth": { + "type": "bearer", + "bearer": [ + { + "key": "token", + "value": "{{access_token}}", + "type": "string" + } + ] + }, + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/json", + "type": "text" + }, + { + "key": "X-Chat-Session", + "value": "{{session_id}}", + "type": "text" + } + ], + "body": { + "mode": "raw", + "raw": "{\n \"question\" : \"{{query}}\",\n \"mode\": \"conversation\",\n \"streaming\": false,\n \"token_data\": true\n}" }, "url": { "raw": "{{baseUrl}}/chat", diff --git a/brevia/query.py b/brevia/query.py index 9c6772a..4c508d7 100644 --- a/brevia/query.py +++ b/brevia/query.py @@ -5,6 +5,7 @@ from langchain_community.vectorstores.pgembedding import CollectionStore from langchain_community.vectorstores.pgvector import DistanceStrategy, PGVector from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnablePassthrough from langchain_core.retrievers import BaseRetriever @@ -173,7 +174,7 @@ def create_conversation_retriever( search_kwargs = chat_params.get_search_kwargs() retriever_conf = collection.cmetadata.get( 'qa_retriever', - get_settings().qa_retriever.copy() + dict(get_settings().qa_retriever).copy() ) if not retriever_conf: return create_default_retriever( @@ -189,13 +190,13 @@ def create_conversation_retriever( document_search, search_kwargs, retriever_conf) -def conversation_chain( +def conversation_rag_chain( collection: CollectionStore, chat_params: ChatParams, answer_callbacks: list[BaseCallbackHandler] | None = None, ) -> Chain: """ - Create and return a conversation chain for Q&A with embedded dataset knowledge. + Create and return a conversation chain for Q&A with embedded dataset knowledge.(RAG) Args: collection (CollectionStore): The collection store item containing the dataset. @@ -234,7 +235,7 @@ def conversation_chain( # Main LLM configuration qa_llm_conf = collection.cmetadata.get( 'qa_completion_llm', - settings.qa_completion_llm.copy() + dict(settings.qa_completion_llm).copy() ) qa_llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks qa_llm_conf['streaming'] = chat_params.streaming @@ -250,7 +251,7 @@ def conversation_chain( # Chain to rewrite question with history fup_llm_conf = collection.cmetadata.get( 'qa_followup_llm', - settings.qa_followup_llm.copy() + dict(settings.qa_followup_llm).copy() ) fup_llm = load_chatmodel(fup_llm_conf) fup_chain = ( @@ -279,3 +280,43 @@ def conversation_chain( ) | retrivial_chain ) + + +def conversation_chain( + chat_params: ChatParams, + answer_callbacks: list[BaseCallbackHandler] | None = None, +) -> Chain: + """ + Create a simple conversation chain for conversation tasks without a collection. + This chain is used for general chat interactions that do not involve a specific + collection or dataset. + """ + + settings = get_settings() + + # Chain to rewrite question with history + fup_llm_conf = dict(settings.qa_followup_llm).copy() + fup_llm = load_chatmodel(fup_llm_conf) + fup_chain = ( + load_condense_prompt() + | fup_llm + | StrOutputParser() + ) + + llm_conf = dict(settings.qa_completion_llm).copy() + llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks + llm_conf['streaming'] = chat_params.streaming + + prompt = PromptTemplate( + input_variables=["question"], + template="\n{question}", + ) + llm = load_chatmodel(llm_conf) + chain = prompt | llm + + return ( + RunnablePassthrough.assign( + question=fup_chain + ) + | chain + ) diff --git a/brevia/routers/qa_router.py b/brevia/routers/qa_router.py index 417e96e..0b35864 100644 --- a/brevia/routers/qa_router.py +++ b/brevia/routers/qa_router.py @@ -1,9 +1,13 @@ """API endpoints for question answering and search""" -from typing import Annotated import asyncio +from glom import glom +from typing import Annotated +from typing_extensions import Self +from pydantic import Field, model_validator from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.chains.base import Chain from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages.ai import AIMessage from fastapi import APIRouter, Header from fastapi.responses import StreamingResponse from brevia import chat_history @@ -17,7 +21,13 @@ token_usage, TokensCallbackHandler, ) -from brevia.query import SearchQuery, ChatParams, conversation_chain, search_vector_qa +from brevia.query import ( + SearchQuery, + ChatParams, + conversation_chain, + conversation_rag_chain, + search_vector_qa, +) from brevia.models import test_models_in_use router = APIRouter() @@ -26,11 +36,19 @@ class ChatBody(ChatParams): """ /chat request body """ question: str - collection: str + collection: str | None = None chat_history: list = [] chat_lang: str | None = None + mode: str = Field(pattern='^(rag|conversation)$', default='rag') token_data: bool = False + @model_validator(mode='after') + def check_collection(self) -> Self: + """Validate collection and mode""" + if not self.collection and self.mode == 'rag': + raise ValueError('Collection required for rag mode') + return self + @router.post('/prompt', dependencies=get_dependencies(), deprecated=True, tags=['Chat']) @router.post('/chat', dependencies=get_dependencies(), tags=['Chat']) @@ -38,20 +56,40 @@ async def chat_action( chat_body: ChatBody, x_chat_session: Annotated[str | None, Header()] = None, ): - """ /chat endpoint, ask chatbot about a collection of documents """ - collection = check_collection_name(chat_body.collection) - if not collection.cmetadata: - collection.cmetadata = {} - lang = chat_language(chat_body=chat_body, cmetadata=collection.cmetadata) - + """ + /chat endpoint, ask chatbot about a collection of documents to perform a rag chat. + If collection is not provided, it will use a simple completion chain. + """ + # Check if collection is provided and valid + collection = None + if chat_body.collection: + collection = check_collection_name(chat_body.collection) + if not collection.cmetadata: + collection.cmetadata = {} + + lang = chat_language( + chat_body=chat_body, + cmetadata=collection.cmetadata if collection else {} + ) conversation_handler = ConversationCallbackHandler() stream_handler = AsyncIteratorCallbackHandler() - chain = conversation_chain( - collection=collection, - chat_params=ChatParams(**chat_body.model_dump()), - answer_callbacks=[stream_handler] if chat_body.streaming else [], - ) - embeddings = collection.cmetadata.get('embeddings', None) + + # Select chain based on chat_body.mode + if chat_body.mode == 'rag': + # RAG-based conversation chain using collection context + chain = conversation_rag_chain( + collection=collection, + chat_params=ChatParams(**chat_body.model_dump()), + answer_callbacks=[stream_handler] if chat_body.streaming else [], + ) + embeddings = collection.cmetadata.get('embeddings', None) + elif chat_body.mode == 'conversation': + # Test mode - currently same as simple conversation + chain = conversation_chain( + chat_params=ChatParams(**chat_body.model_dump()), + answer_callbacks=[stream_handler] if chat_body.streaming else [], + ) + embeddings = None with token_usage_callback() as token_callback: if not chat_body.streaming or test_models_in_use(): @@ -157,13 +195,14 @@ async def run_chain( def chat_result( - result: dict, + result: dict | AIMessage, callb: TokensCallbackHandler, chat_body: ChatBody, x_chat_session: str | None = None, ) -> dict: """ Handle chat result: save chat history and return answer """ - answer = result['answer'].strip(" \n") + answer = glom(result, 'answer', default=glom(result, 'content', default='')) + answer = str(answer).strip(" \n") chat_history_id = None if not chat_body.streaming: @@ -176,9 +215,11 @@ def chat_result( ) chat_history_id = None if chat_hist is None else str(chat_hist.uuid) + context = result['context'] if 'context' in result else None + return { 'bot': answer, - 'docs': None if not chat_body.source_docs else result['context'], + 'docs': None if not chat_body.source_docs else context, 'chat_history_id': chat_history_id, 'token_data': None if not chat_body.token_data else token_usage(callb) } diff --git a/docs/chat_search.md b/docs/chat_search.md index 7d4b4b3..ef65bad 100644 --- a/docs/chat_search.md +++ b/docs/chat_search.md @@ -4,7 +4,7 @@ This section explores the advanced chat functionalities that enable natural and Brevia provides two distinct endpoints for managing conversations: -`/chat`: This endpoint is designed to initiate fluid and natural conversations with the language model. It integrates a conversational memory and chat history system, allowing you to build on previous interactions and create a more engaging experience. +`/chat`: This endpoint is designed to initiate fluid and natural conversations with the language model. It integrates a conversational memory and chat history system, allowing you to build on previous interactions and create a more engaging experience. It supports both Retrieval-Augmented Generation (RAG) mode, where the model retrieves relevant documents from a specified collection to answer questions, and pure conversational mode, where the model relies solely on chat history and its own knowledge. `/completion`: This endpoint is ideal for executing single commands and requests. It provides quick and concise responses without the need for conversational context, making it perfect for launching specific tasks or obtaining immediate information. @@ -49,24 +49,29 @@ Initiates a natural conversation with the model. **Payloads**: -`question`: The query you want to ask the model. -`collection`: The collection of documents to search for relevant information. +- `question`: The query you want to ask the model. +- `collection`: The collection of documents to search for relevant information (mandatory if `mode` is set to `"rag"`). +- `mode` (optional): Specifies the chat mode. + - `"rag"`: Retrieval-Augmented Generation mode. The model answers using information retrieved from the specified collection. + - `"conversation"`: Pure conversational mode. The model answers based only on the chat history and its own knowledge, without retrieving documents. + - If not specified, the default is `"rag"`. -```JSON +```json { "question": "{{query}}", - "collection": "{{collection}}" + "collection": "{{collection}}", + "mode": "rag" } ``` **Optional Parameters**: -`chat_history`: An array of previous questions and answers to provide context for the current query. +- `chat_history`: An array of previous questions and answers to provide context for the current query. -```JSON +```json { "question": "{{query}}", - "collection": "{{collection}}", + "mode": "conversation", "chat_history": [ { "query": "what is artificial intelligence?", @@ -82,33 +87,31 @@ Initiates a natural conversation with the model. **Additional Optional Parameters**: -`docs_num`: The number of documents to retrieve for context. If not specified, the default from settings or collection metadata is used. - -`streaming`: A boolean flag to enable or disable streaming responses. Default is `False`. - -`distance_strategy_name`: The name of the distance strategy to use for vector similarity. Options include `euclidean`, `cosine`, and `max`. Default is `cosine`. - -`filter`: An optional dictionary of metadata to use as a filter for document retrieval. +For `conversation` and `RAG` modes: -`source_docs`: A boolean flag to specify if the retrieved source documents should be included in the response. Default is `False`. +- `streaming`: A boolean flag to enable or disable streaming responses. Default is `False`. -`multiquery`: A boolean flag indicating whether multiple queries should be executed for retrieval. Default is `False`. +For `rag` mode only: -`search_type`: The type of search algorithm to use. Options include: - -- `similarity`: Standard similarity search. -- `similarity_score_threshold`: Similarity search with a score threshold. -- `mmr`: Maximal Marginal Relevance search. -Default is `similarity`. - -`score_threshold`: A numeric threshold for filtering documents based on relevance scores. Default is `0.0` (This applies only when `search_type` is set to `similarity_score_threshold`.). +- `docs_num`: The number of documents to retrieve for context. If not specified, the default from settings or collection metadata is used. +- `distance_strategy_name`: The name of the distance strategy to use for vector similarity. Options include `euclidean`, `cosine`, and `max`. Default is `cosine`. +- `filter`: An optional dictionary of metadata to use as a filter for document retrieval. +- `source_docs`: A boolean flag to specify if the retrieved source documents should be included in the response. Default is `False`. +- `multiquery`: A boolean flag indicating whether multiple queries should be executed for retrieval. Default is `False`. +- `search_type`: The type of search algorithm to use. Options include: + - `similarity`: Standard similarity search. + - `similarity_score_threshold`: Similarity search with a score threshold. + - `mmr`: Maximal Marginal Relevance search. + Default is `similarity`. +- `score_threshold`: A numeric threshold for filtering documents based on relevance scores. Default is `0.0` (applies only when `search_type` is set to `similarity_score_threshold`). **Example Payload**: -```JSON +```json { "question": "What is the capital of France?", "collection": "geography", + "mode": "rag", "chat_history": [ { "query": "What is the largest country in Europe?", @@ -125,6 +128,12 @@ Default is `similarity`. } ``` +**Notes:** + +- If `mode` is not specified, the default behavior is `"rag"`. +- In `"rag"` mode, the model uses both the provided collection and chat history for context. +- In `"conversation"` mode, the model ignores the collection and relies solely on chat history and its own knowledge. + ### POST `/completion` Executes a single command or request without conversational context. diff --git a/docs/endpoints_overview.md b/docs/endpoints_overview.md index ed0be29..80ee171 100644 --- a/docs/endpoints_overview.md +++ b/docs/endpoints_overview.md @@ -8,24 +8,30 @@ Initiates a natural conversation with the model. **Payloads**: -`question`: The query you want to ask the model. -`collection`: The collection of documents to search for relevant information. - -```JSON +- `question`: The query you want to ask the model. +- `collection`: The collection of documents to search for relevant information. Mandatory if `mode` is set to `"rag"`. +- `mode` (optional): Specifies the chat mode. + - `"rag"`: Retrieval-Augmented Generation mode. The model answers using information retrieved from the specified collection. + - `"conversation"`: Pure conversational mode. The model answers based only on the chat history and its own knowledge, without retrieving documents. + - If not specified, the default is `"rag"`. + +```json { "question": "{{query}}", - "collection": "{{collection}}" + "collection": "{{collection}}", + "mode": "rag" } ``` **Optional Parameters**: -`chat_history`: An array of previous questions and answers to provide context for the current query. +- `chat_history`: An array of previous questions and answers to provide context for the current query. -```JSON +```json { "question": "{{query}}", "collection": "{{collection}}", + "mode": "conversation", "chat_history": [ { "query": "what is artificial intelligence?", @@ -39,10 +45,31 @@ Initiates a natural conversation with the model. } ``` -`source_docs`: Set to true to return source documents. -`docs_num`: Specify the number of documents to return. -`token_data`: Set to true to return token-level data like part-of-speech tags. -`multiquery`: Set to true to use MultiQueryRetriever from langchain. +**Additional Optional Parameters**: + +For `conversation` and `RAG` modes: + +- `streaming`: A boolean flag to enable or disable streaming responses. Default is `False`. + +For `rag` mode only: + +- `docs_num`: The number of documents to retrieve for context. If not specified, the default from settings or collection metadata is used. +- `distance_strategy_name`: The name of the distance strategy to use for vector similarity. Options include `euclidean`, `cosine`, and `max`. Default is `cosine`. +- `filter`: An optional dictionary of metadata to use as a filter for document retrieval. +- `source_docs`: A boolean flag to specify if the retrieved source documents should be included in the response. Default is `False`. +- `multiquery`: A boolean flag indicating whether multiple queries should be executed for retrieval. Default is `False`. +- `search_type`: The type of search algorithm to use. Options include: + - `similarity`: Standard similarity search. + - `similarity_score_threshold`: Similarity search with a score threshold. + - `mmr`: Maximal Marginal Relevance search. + Default is `similarity`. +- `score_threshold`: A numeric threshold for filtering documents based on relevance scores. Default is `0.0` (applies only when `search_type` is set to `similarity_score_threshold`). + +**Notes:** + +- If `mode` is not specified, the default behavior is `"rag"`. +- In `"rag"` mode, the model uses both the provided collection and chat history for context. +- In `"conversation"` mode, the model ignores the collection and relies solely on chat history and its own knowledge. ### POST `/completion` diff --git a/poetry.lock b/poetry.lock index cb8d421..5eb3166 100644 --- a/poetry.lock +++ b/poetry.lock @@ -281,6 +281,18 @@ charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "boltons" +version = "25.0.0" +description = "When they're not builtins, they're boltons." +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "boltons-25.0.0-py3-none-any.whl", hash = "sha256:dc9fb38bf28985715497d1b54d00b62ea866eca3938938ea9043e254a3a6ca62"}, + {file = "boltons-25.0.0.tar.gz", hash = "sha256:e110fbdc30b7b9868cb604e3f71d4722dd8f4dcb4a5ddd06028ba8f1ab0b5ace"}, +] + [[package]] name = "bs4" version = "0.0.1" @@ -699,6 +711,21 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "face" +version = "24.0.0" +description = "A command-line application framework (and CLI parser). Friendly for users, full-featured for developers." +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "face-24.0.0-py3-none-any.whl", hash = "sha256:0e2c17b426fa4639a4e77d1de9580f74a98f4869ba4c7c8c175b810611622cd3"}, + {file = "face-24.0.0.tar.gz", hash = "sha256:611e29a01ac5970f0077f9c577e746d48c082588b411b33a0dd55c4d872949f6"}, +] + +[package.dependencies] +boltons = ">=20.0.0" + [[package]] name = "fastapi" version = "0.115.12" @@ -943,6 +970,27 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "glom" +version = "24.11.0" +description = "A declarative object transformer and formatter, for conglomerating nested data." +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "glom-24.11.0-py3-none-any.whl", hash = "sha256:991db7fcb4bfa9687010aa519b7b541bbe21111e70e58fdd2d7e34bbaa2c1fbd"}, + {file = "glom-24.11.0.tar.gz", hash = "sha256:4325f96759a912044af7b6c6bd0dba44ad8c1eb6038aab057329661d2021bb27"}, +] + +[package.dependencies] +attrs = "*" +boltons = ">=19.3.0" +face = ">=20.1.1" + +[package.extras] +toml = ["tomli ; python_version < \"3.11\""] +yaml = ["PyYAML"] + [[package]] name = "greenlet" version = "3.1.1" @@ -3948,4 +3996,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "ce226e5b264b59d59d08f95eff41ba2da4d6efc5193acd31e6178baf7bddcc69" +content-hash = "e3013f472ac94245bd324b4c9390ae2ffabac32aa9566f97adaaf12af3149b70" diff --git a/pyproject.toml b/pyproject.toml index b1c1405..03ee1f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ repository = "https://github.com/brevia-ai/brevia" pgvector = "^0.4.0" pypdf = "^5.4.0" lxml = "^5.3.1" + glom = "^24.11.0" [tool.poetry.dependencies.uvicorn] version = "^0.34.2" diff --git a/tests/routers/test_qa_router.py b/tests/routers/test_qa_router.py index 73de27e..ae8de72 100644 --- a/tests/routers/test_qa_router.py +++ b/tests/routers/test_qa_router.py @@ -95,7 +95,11 @@ def test_search_filter(): def test_chat_language(): """Test chat_language method""" - chat_body = ChatBody(question='', collection='', chat_lang='Klingon') + chat_body = ChatBody( + question='question', + collection='collection', + chat_lang='Klingon' + ) lang = chat_language(chat_body=chat_body, cmetadata={}) assert lang == 'Klingon' @@ -116,3 +120,35 @@ def test_extract_content_score(): """Test extract_content_score method""" result = extract_content_score(data_list={'error': 'big problems!'}) assert result == {'error': 'big problems!'} + + +def test_chat_invalid_mode(): + """Test POST /chat with invalid mode""" + create_collection('test_collection', {}) + response = client.post( + '/chat', + headers={'Content-Type': 'application/json'}, + content=dumps({ + "question": "How are you?", + "collection": "test_collection", + "mode": "invalid_mode" + }) + ) + assert response.status_code == 422 # Validation error + data = response.json() + assert 'mode' in str(data['detail']) # Verify error mentions mode field + + +def test_chat_invalid_collection(): + """Test POST /chat with non-existent collection""" + response = client.post( + '/chat', + headers={'Content-Type': 'application/json'}, + content=dumps({ + "question": "How are you?", + "mode": "rag" + }) + ) + assert response.status_code == 422 # Validation error + data = response.json() + assert 'Collection' in str(data['detail']) diff --git a/tests/test_chat_history.py b/tests/test_chat_history.py index 9398bb9..eb4122a 100644 --- a/tests/test_chat_history.py +++ b/tests/test_chat_history.py @@ -1,7 +1,6 @@ """chat_history module tests""" from datetime import datetime, timedelta import uuid -import pytest from brevia.chat_history import ( history, add_history, @@ -33,10 +32,15 @@ def test_add_history(): def test_add_history_failure(): - """Test history_from_db failure""" - with pytest.raises(ValueError) as exc: - add_history(uuid.uuid4(), 'test', 'who?', 'me') - assert str(exc.value) == 'Collection not found' + """Test add_history with non-existent collection""" + session_id = uuid.uuid4() + # When collection doesn't exist, it should + # still add the history but with collection_id = None + history_item = add_history(session_id, 'non_existent_collection', 'who?', 'me') + assert history_item is not None + assert history_item.collection_id is None + assert history_item.question == 'who?' + assert history_item.answer == 'me' def test_get_history(): diff --git a/tests/test_query.py b/tests/test_query.py index e926285..f35b947 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,12 +1,14 @@ """Query module tests""" import pytest from langchain.docstore.document import Document -from langchain_core.vectorstores import VectorStoreRetriever from langchain.retrievers.multi_query import MultiQueryRetriever +from langchain_core.vectorstores import VectorStoreRetriever +from langchain_core.messages.ai import AIMessage from langchain_core.runnables import Runnable from brevia.models import load_chatmodel from brevia.query import ( conversation_chain, + conversation_rag_chain, create_conversation_retriever, search_vector_qa, ChatParams, @@ -81,14 +83,79 @@ def test_search_vector_filter(): def test_conversation_chain(): - """Test conversation_chain function""" + """Test simple conversation_chain function without collection""" + chain = conversation_chain(chat_params=ChatParams()) + assert chain is not None + assert isinstance(chain, Runnable) + + # Test with streaming enabled + chain = conversation_chain( + chat_params=ChatParams(streaming=True), + answer_callbacks=[], + ) + assert chain is not None + assert isinstance(chain, Runnable) + + +def test_conversation_chain_output(): + """Test conversation_chain output format""" + chain = conversation_chain(chat_params=ChatParams()) + result = chain.invoke({ + 'question': 'What is 2+2?', + 'chat_history': [], + 'lang': '', + }) + + assert isinstance(result, AIMessage) + assert hasattr(result, 'content') + assert isinstance(result.content, str) + assert len(result.content) > 0 + + +def test_conversation_rag_chain(): + """Test RAG-based conversation_rag_chain function""" collection = create_collection('test', {}) - chain = conversation_chain(collection=collection, chat_params=ChatParams()) + chain = conversation_rag_chain( + collection=collection, + chat_params=ChatParams(), + ) + assert chain is not None + assert isinstance(chain, Runnable) + # Test with streaming enabled + chain = conversation_rag_chain( + collection=collection, + chat_params=ChatParams(streaming=True), + answer_callbacks=[], + ) assert chain is not None assert isinstance(chain, Runnable) +def test_conversation_rag_chain_with_docs(): + """Test conversation_rag_chain with documents and query""" + collection = create_collection('test', {}) + doc = Document(page_content='The answer to life is 42', metadata={}) + add_document(document=doc, collection_name='test') + + chain = conversation_rag_chain( + collection=collection, + chat_params=ChatParams(source_docs=True), + ) + + result = chain.invoke({ + 'question': 'What is the answer to life?', + 'chat_history': [], + 'lang': '', + }) + + assert isinstance(result, dict) + assert 'answer' in result + assert 'context' in result + assert isinstance(result['answer'], str) + assert isinstance(result['context'], list) + + def test_conversation_retriever(): """Test create_conversation_retriever function with multiquery""" collection = create_collection('test', {})