7 changes: 6 additions & 1 deletion brevia/callback.py
@@ -4,6 +4,7 @@
import asyncio
import logging
import json
from glom import glom
from langchain_community.callbacks import OpenAICallbackHandler, get_openai_callback
from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.documents import Document
@@ -81,7 +82,11 @@ async def on_chain_end(
) -> None:
"""Run when chain ends running."""
if parent_run_id is None:
self.answer = outputs.get('answer')
self.answer = glom(
outputs,
'answer',
default=glom(outputs, 'content', default='')
)
self.chain_ended.set()

async def wait_conversation_done(self):
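A minimal standalone sketch of the fallback lookup introduced above (the sample outputs are hypothetical):

```python
from glom import glom

def extract_answer(outputs: dict) -> str:
    # Prefer the RAG-style 'answer' key, fall back to a plain chat 'content' key.
    return glom(outputs, 'answer', default=glom(outputs, 'content', default=''))

print(extract_answer({'answer': 'Paris'}))    # RAG chain output -> 'Paris'
print(extract_answer({'content': 'Hello!'}))  # plain LLM message dict -> 'Hello!'
print(extract_answer({}))                     # neither key present -> ''
```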
10 changes: 6 additions & 4 deletions brevia/chat_history.py
@@ -115,12 +115,14 @@ def add_history(
return None

with Session(db_connection()) as session:
collection_store = CollectionStore.get_by_name(session, collection)
if not collection_store:
raise ValueError("Collection not found")
collection_uuid = None  # No collection associated by default
if collection:
collection_store = CollectionStore.get_by_name(session, collection)
collection_uuid = collection_store.uuid if collection_store else None

chat_history_store = ChatHistoryStore(
session_id=session_id,
collection_id=collection_store.uuid,
collection_id=collection_uuid,
question=question,
answer=answer,
cmetadata=metadata,
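With the collection made optional, history rows can now be saved for collection-less chats as well. A hedged usage sketch (keyword names inferred from the diff; actually persisting rows requires a configured Brevia database):

```python
from brevia import chat_history

# RAG chat: the collection name is resolved to its UUID via CollectionStore
chat_history.add_history(
    session_id='8f2a0000-0000-0000-0000-000000000000',  # placeholder session UUID
    collection='my-docs',
    question='What is lorem ipsum?',
    answer='Lorem ipsum dolor sit amet...',
)

# Plain conversation: no collection, so collection_id is stored as None
chat_history.add_history(
    session_id='8f2a0000-0000-0000-0000-000000000000',
    collection=None,
    question='Hello!',
    answer='Hi, how can I help?',
)
```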
7 changes: 5 additions & 2 deletions brevia/models.py
@@ -22,8 +22,11 @@ def get_token_ids(self, text: str) -> list[int]:
return [10] * 10


LOREM_IPSUM = """Lorem ipsum dolor sit amet, consectetur adipisici elit,
sed eiusmod tempor incidunt ut labore et dolore magna aliqua."""
LOREM_IPSUM = """{
"question": "What is lorem ipsum?",
"answer": "Lorem ipsum dolor sit amet, consectetur adipisici elit,
sed eiusmod tempor incidunt ut labore et dolore magna aliqua."
}"""


def load_llm(config: dict) -> BaseLLM:
44 changes: 43 additions & 1 deletion brevia/postman/Brevia API.postman_collection.json
@@ -37,7 +37,49 @@
],
"body": {
"mode": "raw",
"raw": "{\n \"question\" : \"{{query}}\",\n \"collection\" : \"{{collection}}\"\n}"
"raw": "{\n \"question\" : \"{{query}}\",\n \"mode\": \"rag\",\n \"collection\" : \"{{collection}}\"\n}"
},
"url": {
"raw": "{{baseUrl}}/chat",
"host": [
"{{baseUrl}}"
],
"path": [
"chat"
]
}
},
"response": []
},
{
"name": "chat - conversation, no RAG",
"request": {
"auth": {
"type": "bearer",
"bearer": [
{
"key": "token",
"value": "{{access_token}}",
"type": "string"
}
]
},
"method": "POST",
"header": [
{
"key": "Content-Type",
"value": "application/json",
"type": "text"
},
{
"key": "X-Chat-Session",
"value": "{{session_id}}",
"type": "text"
}
],
"body": {
"mode": "raw",
"raw": "{\n \"question\" : \"{{query}}\",\n \"mode\": \"conversation\",\n \"streaming\": false,\n \"token_data\": true\n}"
},
"url": {
"raw": "{{baseUrl}}/chat",
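The new conversation-mode request can also be exercised outside Postman. A hedged sketch using `requests` (base URL, token and session id are placeholders; field names follow the body above, the `bot` response key comes from `chat_result()` below):

```python
import requests

response = requests.post(
    'http://localhost:8000/chat',           # assumed local Brevia API instance
    headers={
        'Authorization': 'Bearer <access_token>',
        'Content-Type': 'application/json',
        'X-Chat-Session': '<session_id>',
    },
    json={
        'question': 'What is lorem ipsum?',
        'mode': 'conversation',              # plain chat, no collection needed
        'streaming': False,
        'token_data': True,
    },
    timeout=30,
)
print(response.json()['bot'])                # answer text
```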
64 changes: 58 additions & 6 deletions brevia/query.py
@@ -5,13 +5,14 @@
from langchain_community.vectorstores.pgembedding import CollectionStore
from langchain_community.vectorstores.pgvector import DistanceStrategy, PGVector
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_core.language_models import BaseChatModel
from langchain_core.documents import Document
from pydantic import BaseModel
from pydantic import BaseModel, Field
from brevia.connection import connection_string
from brevia.collections import single_collection_by_name
from brevia.models import load_chatmodel, load_embeddings
@@ -173,7 +174,7 @@ def create_conversation_retriever(
search_kwargs = chat_params.get_search_kwargs()
retriever_conf = collection.cmetadata.get(
'qa_retriever',
get_settings().qa_retriever.copy()
dict(get_settings().qa_retriever).copy()
)
if not retriever_conf:
return create_default_retriever(
@@ -189,13 +190,13 @@
document_search, search_kwargs, retriever_conf)


def conversation_chain(
def conversation_rag_chain(
collection: CollectionStore,
chat_params: ChatParams,
answer_callbacks: list[BaseCallbackHandler] | None = None,
) -> Chain:
"""
Create and return a conversation chain for Q&A with embedded dataset knowledge.
Create and return a conversation chain for Q&A with embedded dataset knowledge (RAG).

Args:
collection (CollectionStore): The collection store item containing the dataset.
@@ -234,7 +235,7 @@ def conversation_chain(
# Main LLM configuration
qa_llm_conf = collection.cmetadata.get(
'qa_completion_llm',
settings.qa_completion_llm.copy()
dict(settings.qa_completion_llm).copy()
)
qa_llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks
qa_llm_conf['streaming'] = chat_params.streaming
@@ -250,7 +251,7 @@
# Chain to rewrite question with history
fup_llm_conf = collection.cmetadata.get(
'qa_followup_llm',
settings.qa_followup_llm.copy()
dict(settings.qa_followup_llm).copy()
)
fup_llm = load_chatmodel(fup_llm_conf)
fup_chain = (
@@ -279,3 +280,54 @@
)
| retrivial_chain
)


def conversation_chain(
chat_params: ChatParams,
answer_callbacks: list[BaseCallbackHandler] | None = None,
) -> Chain:
"""
Create a simple conversation chain for conversation tasks without a collection.
This chain is used for general chat interactions that do not involve a specific
collection or dataset.
"""

# Define your desired data structure.
class Result(BaseModel):
Review comment by @stefanorosanelli (Contributor), May 30, 2025:
can we remove this Result class? seems unused now
"""
Result model for the conversation chain output.
Attributes:
question (str): The question asked in the conversation.
answer (str): The answer provided in response to the question.
"""
question: str = Field(description="The question asked in the conversation.")
answer: str = Field(description="The answer to the question.")

settings = get_settings()

# Chain to rewrite question with history
fup_llm_conf = dict(settings.qa_followup_llm).copy()
fup_llm = load_chatmodel(fup_llm_conf)
fup_chain = (
load_condense_prompt()
| fup_llm
| StrOutputParser()
)

llm_conf = dict(settings.qa_completion_llm).copy()
llm_conf['callbacks'] = [] if answer_callbacks is None else answer_callbacks
llm_conf['streaming'] = chat_params.streaming

prompt = PromptTemplate(
input_variables=["question"],
template="\n{question}",
)
llm = load_chatmodel(llm_conf)
chain = prompt | llm

return (
RunnablePassthrough.assign(
question=fup_chain
)
| chain
)
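The shape of the new collection-less chain — rewrite the question with history, then answer — can be illustrated with stub runnables (the real chain uses load_condense_prompt() and load_chatmodel() as above; the stubs are placeholders):

```python
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Stand-ins for fup_chain (question rewriter) and the final prompt | llm chain
fup_chain = RunnableLambda(
    lambda x: f"{x['question']} (condensed with {len(x.get('chat_history', []))} history turns)"
)
main_chain = RunnableLambda(lambda x: {'answer': f"Answering: {x['question']}"})

# RunnablePassthrough.assign replaces 'question' with the rewritten one,
# then the main chain consumes the updated input dict.
pipeline = RunnablePassthrough.assign(question=fup_chain) | main_chain
print(pipeline.invoke({'question': 'What is Brevia?', 'chat_history': []}))
```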
77 changes: 59 additions & 18 deletions brevia/routers/qa_router.py
@@ -1,9 +1,13 @@
"""API endpoints for question answering and search"""
from typing import Annotated
import asyncio
from glom import glom
from typing import Annotated
from typing_extensions import Self
from pydantic import Field, model_validator
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains.base import Chain
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages.ai import AIMessage
from fastapi import APIRouter, Header
from fastapi.responses import StreamingResponse
from brevia import chat_history
@@ -17,7 +21,13 @@
token_usage,
TokensCallbackHandler,
)
from brevia.query import SearchQuery, ChatParams, conversation_chain, search_vector_qa
from brevia.query import (
SearchQuery,
ChatParams,
conversation_chain,
conversation_rag_chain,
search_vector_qa,
)
from brevia.models import test_models_in_use

router = APIRouter()
@@ -26,32 +36,60 @@
class ChatBody(ChatParams):
""" /chat request body """
question: str
collection: str
collection: str | None = None
chat_history: list = []
chat_lang: str | None = None
mode: str = Field(pattern='^(rag|conversation)$', default='rag')
token_data: bool = False

@model_validator(mode='after')
def check_collection(self) -> Self:
"""Validate collection and mode"""
if not self.collection and self.mode == 'rag':
raise ValueError('Collection required for rag mode')
return self


@router.post('/prompt', dependencies=get_dependencies(), deprecated=True, tags=['Chat'])
@router.post('/chat', dependencies=get_dependencies(), tags=['Chat'])
async def chat_action(
chat_body: ChatBody,
x_chat_session: Annotated[str | None, Header()] = None,
):
""" /chat endpoint, ask chatbot about a collection of documents """
collection = check_collection_name(chat_body.collection)
if not collection.cmetadata:
collection.cmetadata = {}
lang = chat_language(chat_body=chat_body, cmetadata=collection.cmetadata)

"""
/chat endpoint: ask the chatbot about a collection of documents (RAG chat).
If no collection is provided, a plain conversation chain is used instead.
"""
# Check if collection is provided and valid
collection = None
if chat_body.collection:
collection = check_collection_name(chat_body.collection)
if not collection.cmetadata:
collection.cmetadata = {}

lang = chat_language(
chat_body=chat_body,
cmetadata=collection.cmetadata if collection else {}
)
conversation_handler = ConversationCallbackHandler()
stream_handler = AsyncIteratorCallbackHandler()
chain = conversation_chain(
collection=collection,
chat_params=ChatParams(**chat_body.model_dump()),
answer_callbacks=[stream_handler] if chat_body.streaming else [],
)
embeddings = collection.cmetadata.get('embeddings', None)

# Select chain based on chat_body.mode
if chat_body.mode == 'rag':
# RAG-based conversation chain using collection context
chain = conversation_rag_chain(
collection=collection,
chat_params=ChatParams(**chat_body.model_dump()),
answer_callbacks=[stream_handler] if chat_body.streaming else [],
)
embeddings = collection.cmetadata.get('embeddings', None)
elif chat_body.mode == 'conversation':
# Plain conversation chain, no collection context
chain = conversation_chain(
chat_params=ChatParams(**chat_body.model_dump()),
answer_callbacks=[stream_handler] if chat_body.streaming else [],
)
embeddings = None

with token_usage_callback() as token_callback:
if not chat_body.streaming or test_models_in_use():
@@ -157,13 +195,14 @@ async def run_chain(


def chat_result(
result: dict,
result: dict | AIMessage,
callb: TokensCallbackHandler,
chat_body: ChatBody,
x_chat_session: str | None = None,
) -> dict:
""" Handle chat result: save chat history and return answer """
answer = result['answer'].strip(" \n")
answer = glom(result, 'answer', default=glom(result, 'content', default=''))
answer = str(answer).strip(" \n")

chat_history_id = None
if not chat_body.streaming:
@@ -176,9 +215,11 @@ def chat_result(
)
chat_history_id = None if chat_hist is None else str(chat_hist.uuid)

context = result['context'] if 'context' in result else None

return {
'bot': answer,
'docs': None if not chat_body.source_docs else result['context'],
'docs': None if not chat_body.source_docs else context,
'chat_history_id': chat_history_id,
'token_data': None if not chat_body.token_data else token_usage(callb)
}
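A hedged sketch of the new mode/collection validation in ChatBody (import path assumed from the repository layout, and assuming the remaining ChatParams fields all have defaults):

```python
from pydantic import ValidationError
from brevia.routers.qa_router import ChatBody

ChatBody(question='Hi', mode='conversation')             # OK: no collection required
ChatBody(question='Hi', mode='rag', collection='docs')   # OK: RAG with a collection
try:
    ChatBody(question='Hi', mode='rag')                  # RAG without a collection
except ValidationError as exc:
    print(exc)  # raised by check_collection: 'Collection required for rag mode'
```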