Merge pull request #142 from codebanesr/enhancement/chain_selector
Enhancement/chain selector
codebanesr authored Aug 19, 2023
2 parents fb02a75 + 97848fa commit b8a8a1d
Showing 3 changed files with 32 additions and 8 deletions.
README.md (8 additions, 0 deletions)
```diff
@@ -108,6 +108,14 @@ If you want to switch from Pinecone to Qdrant, you can set the following environ
 - `STORE`: The store to use to store embeddings. Can be `qdrant` or `pinecone`.
 
 
+#### Optional [To modify the chat behaviour]
+
+`CHAIN_TYPE`: The type of chain to use. Can be `conversation_retrieval` or `retrieval_qa`.
+
+- `retrieval_qa` -> [Learn more](https://python.langchain.com/docs/use_cases/question_answering/how_to/vector_db_qa)
+- `conversation_retrieval` -> [Learn more](https://python.langchain.com/docs/use_cases/question_answering/how_to/chat_vector_db)
+
+
 > Note: for Pinecone, make sure that the embedding dimension is equal to 1536
 - Navigate to the repository folder and run the following command (for macOS or Linux):
```
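For reference, this is how the new setting gets picked up at runtime, in a minimal sketch that mirrors the `views_chat.py` change further down; only `python-dotenv` and the standard library are involved:

```python
import os

from dotenv import load_dotenv

# .env (project root), e.g.:
#   CHAIN_TYPE=retrieval_qa
load_dotenv()

# Falls back to the conversational chain when CHAIN_TYPE is unset.
chain_type = os.getenv("CHAIN_TYPE", "conversation_retrieval")
print(chain_type)
```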
dj_backend_server/api/utils/make_chain.py (1 addition, 1 deletion)
```diff
@@ -30,7 +30,7 @@ def getRetrievalQAWithSourcesChain(vector_store: VectorStore, mode, initial_prom
     return chain
 
 
-def getConversationRetrievalChain(vector_store: VectorStore, mode, initial_prompt: str, memory_key: str):
+def getConversationRetrievalChain(vector_store: VectorStore, mode, initial_prompt: str):
     llm = get_llm()
     template = get_qa_prompt_by_mode(mode, initial_prompt=initial_prompt)
     prompt = PromptTemplate.from_template(template)
```
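The hunk cuts off before the chain is actually built. A plausible completion under LangChain's 2023-era API is sketched below; the `ConversationalRetrievalChain.from_llm` call and its keyword arguments are assumptions, not shown in this diff, and `get_llm`/`get_qa_prompt_by_mode` are project helpers already imported in `make_chain.py`:

```python
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain.vectorstores.base import VectorStore


def getConversationRetrievalChain(vector_store: VectorStore, mode, initial_prompt: str):
    llm = get_llm()  # project helper, not shown in this diff
    template = get_qa_prompt_by_mode(mode, initial_prompt=initial_prompt)
    prompt = PromptTemplate.from_template(template)
    # Assumed continuation: chat history is supplied per call via the
    # "chat_history" input, so no Memory object (and hence no memory_key)
    # is needed, which is why the parameter could be dropped.
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_store.as_retriever(),
        combine_docs_chain_kwargs={"prompt": prompt},
    )
```

Passing history explicitly per request also explains why the caller in `views_chat.py` no longer forwards `session_id` as a `memory_key`.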
dj_backend_server/api/views/views_chat.py (23 additions, 7 deletions)
```diff
@@ -1,8 +1,9 @@
 from django.http import JsonResponse
 from django.views.decorators.http import require_POST
+from langchain import QAWithSourcesChain
 
 from api.utils import get_vector_store
-from api.utils.make_chain import getConversationRetrievalChain
+from api.utils.make_chain import getConversationRetrievalChain, getRetrievalQAWithSourcesChain
 import json
 from django.views.decorators.csrf import csrf_exempt
 from api.interfaces import StoreOptions
```
```diff
@@ -13,6 +14,10 @@
 import logging
 import traceback
 from web.services.chat_history_service import get_chat_history_for_retrieval_chain
+import os
+
+from dotenv import load_dotenv
+load_dotenv()
 
 logger = logging.getLogger(__name__)
```

```diff
@@ -36,12 +41,8 @@ def chat(request):
         sanitized_question = question.strip().replace('\n', ' ')
 
         vector_store = get_vector_store(StoreOptions(namespace=namespace))
-        chain = getConversationRetrievalChain(vector_store, mode, initial_prompt, memory_key=session_id)
 
-        # To avoid fetching an excessively large amount of history data from the database, set a limit on the maximum number of records that can be retrieved in a single query.
-        chat_history = get_chat_history_for_retrieval_chain(session_id, limit=40)
-        response = chain({"question": sanitized_question, "chat_history": chat_history}, return_only_outputs=True)
-        response_text = response['answer']
+        response_text = get_completion_response(vector_store=vector_store, initial_prompt=initial_prompt, mode=mode, sanitized_question=sanitized_question, session_id=session_id)
 
         ChatHistory.objects.bulk_create([
             ChatHistory(
```
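To see the refactor from the caller's side, here is a hypothetical request against the endpoint; the URL and payload keys are inferred from the handler's locals (`question`, `namespace`, `session_id`) and are assumptions, since the parsing code sits above this hunk:

```python
import requests

# Hypothetical payload; keys mirror the locals used in chat().
payload = {
    "question": "How do I switch the store from Pinecone to Qdrant?",
    "namespace": "docs",           # selects the vector-store namespace
    "session_id": "demo-session",  # keys the stored chat history
}

# Hypothetical URL; check the project's urls.py for the real route.
resp = requests.post("http://localhost:8000/api/chat/", json=payload)
print(resp.json())
```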
```diff
@@ -68,4 +69,19 @@ def chat(request):
     except Exception as e:
         logger.error(str(e))
         logger.error(traceback.format_exc())
-        return JsonResponse({'error': 'An error occurred'}, status=500)
+        return JsonResponse({'error': 'An error occurred'}, status=500)
+
+
+def get_completion_response(vector_store, mode, initial_prompt, sanitized_question, session_id):
+    chain_type = os.getenv("CHAIN_TYPE", "conversation_retrieval")
+    chain: QAWithSourcesChain
+    if chain_type == 'retrieval_qa':
+        chain = getRetrievalQAWithSourcesChain(vector_store, mode, initial_prompt)
+        response = chain({"question": sanitized_question}, return_only_outputs=True)
+        response_text = response['answer']
+    elif chain_type == 'conversation_retrieval':
+        chain = getConversationRetrievalChain(vector_store, mode, initial_prompt)
+        chat_history = get_chat_history_for_retrieval_chain(session_id, limit=40)
+        response = chain({"question": sanitized_question, "chat_history": chat_history}, return_only_outputs=True)
+        response_text = response['answer']
+    return response_text
```
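One caveat with the new helper: if `CHAIN_TYPE` is set to anything other than the two recognized values, neither branch runs and the final `return response_text` raises `UnboundLocalError`. A defensive variant (a sketch, not part of the commit) could fail fast instead:

```python
def get_completion_response(vector_store, mode, initial_prompt, sanitized_question, session_id):
    chain_type = os.getenv("CHAIN_TYPE", "conversation_retrieval")
    if chain_type == 'retrieval_qa':
        chain = getRetrievalQAWithSourcesChain(vector_store, mode, initial_prompt)
        inputs = {"question": sanitized_question}
    elif chain_type == 'conversation_retrieval':
        chain = getConversationRetrievalChain(vector_store, mode, initial_prompt)
        # Cap the history fetched per query to keep the prompt bounded.
        chat_history = get_chat_history_for_retrieval_chain(session_id, limit=40)
        inputs = {"question": sanitized_question, "chat_history": chat_history}
    else:
        # Reject unrecognized values instead of falling through silently.
        raise ValueError(f"Unsupported CHAIN_TYPE: {chain_type!r}")
    response = chain(inputs, return_only_outputs=True)
    return response['answer']
```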
