diff --git a/query_data.py b/query_data.py
index c0028317f..1a137ae7f 100644
--- a/query_data.py
+++ b/query_data.py
@@ -1,7 +1,7 @@
-"""Create a ChatVectorDBChain for question/answering."""
+"""Create a ConversationalRetrievalChain for question/answering."""
 from langchain.callbacks.base import AsyncCallbackManager
 from langchain.callbacks.tracers import LangChainTracer
-from langchain.chains import ChatVectorDBChain
+from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
                                                      QA_PROMPT)
 from langchain.chains.llm import LLMChain
@@ -12,9 +12,9 @@
 
 def get_chain(
     vectorstore: VectorStore, question_handler, stream_handler, tracing: bool = False
-) -> ChatVectorDBChain:
-    """Create a ChatVectorDBChain for question/answering."""
-    # Construct a ChatVectorDBChain with a streaming llm for combine docs
+) -> ConversationalRetrievalChain:
+    """Create a ConversationalRetrievalChain for question/answering."""
+    # Construct a ConversationalRetrievalChain with a streaming llm for combine docs
     # and a separate, non-streaming llm for question generation
     manager = AsyncCallbackManager([])
     question_manager = AsyncCallbackManager([question_handler])
@@ -45,10 +45,11 @@ def get_chain(
         streaming_llm, chain_type="stuff", prompt=QA_PROMPT, callback_manager=manager
     )
 
-    qa = ChatVectorDBChain(
-        vectorstore=vectorstore,
+    qa = ConversationalRetrievalChain(
+        retriever=vectorstore.as_retriever(),
         combine_docs_chain=doc_chain,
         question_generator=question_generator,
         callback_manager=manager,
+        verbose=True
     )
     return qa
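
Note, not part of the patch: a minimal usage sketch of the migrated chain, assuming `vectorstore`, `question_handler`, and `stream_handler` are constructed by the caller; the `answer_question` helper below is hypothetical. Both the old ChatVectorDBChain and the new ConversationalRetrievalChain take "question" and "chat_history" inputs and return the generated text under the "answer" key, so existing call sites should keep working unchanged; only the constructor differs, with the vectorstore now wrapped via as_retriever().

    from query_data import get_chain

    async def answer_question(vectorstore, question_handler, stream_handler) -> str:
        """Run one question through the chain built by the patched get_chain()."""
        qa = get_chain(vectorstore, question_handler, stream_handler)
        # ConversationalRetrievalChain expects the same input keys that
        # ChatVectorDBChain did: "question" plus a (possibly empty) "chat_history".
        result = await qa.acall({"question": "What is LangChain?", "chat_history": []})
        return result["answer"]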