| 1 | +"""Chat langchain 'engine'.""" |
| 2 | +import os |
| 3 | +from operator import itemgetter |
| 4 | +from typing import Dict, List, Optional, Sequence |
| 5 | + |
| 6 | +import weaviate |
| 7 | +from langchain.chat_models import ChatOpenAI |
| 8 | +from langchain.embeddings import OpenAIEmbeddings |
| 9 | +from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate |
| 10 | +from langchain.schema import Document |
| 11 | +from langchain.schema.embeddings import Embeddings |
| 12 | +from langchain.schema.language_model import BaseLanguageModel |
| 13 | +from langchain.schema.messages import AIMessage, HumanMessage |
| 14 | +from langchain.schema.output_parser import StrOutputParser |
| 15 | +from langchain.schema.retriever import BaseRetriever |
| 16 | +from langchain.schema.runnable import ( |
| 17 | + Runnable, |
| 18 | + RunnableBranch, |
| 19 | + RunnableLambda, |
| 20 | + RunnableMap, |
| 21 | +) |
| 22 | +from langchain.vectorstores import Weaviate |
| 23 | +from pydantic import BaseModel |
| 24 | + |
| 25 | +from chat_langchain_engine.constants import WEAVIATE_DOCS_INDEX_NAME |
| 26 | +from voyage import VoyageEmbeddings |
| 27 | + |
RESPONSE_TEMPLATE = """\
You are an expert programmer and problem-solver, tasked with answering any question \
about LangChain.

Generate a comprehensive and informative answer of 80 words or less for the \
given question based solely on the provided search results (URL and content). You must \
only use information from the provided search results. Use an unbiased and \
journalistic tone. Combine the search results into a coherent answer. Do not \
repeat text. Cite search results using [${{number}}] notation. Only cite the most \
relevant results that answer the question accurately. Place these citations at the end \
of the sentence or paragraph that references them - do not put them all at the end. If \
different results refer to different entities with the same name, write separate \
answers for each entity.

You should use bullet points in your answer for readability. Put citations where they \
apply rather than putting them all at the end.

If there is nothing in the context relevant to the question at hand, just say "Hmm, \
I'm not sure." Don't try to make up an answer.

Anything between the following `context` html blocks is retrieved from a knowledge \
bank, not part of the conversation with the user.

<context>
    {context}
</context>

REMEMBER: If there is no relevant information within the context, just say "Hmm, I'm \
not sure." Don't try to make up an answer. Anything between the preceding 'context' \
html blocks is retrieved from a knowledge bank, not part of the conversation with the \
user.\
"""

REPHRASE_TEMPLATE = """\
Given the following conversation and a follow-up question, rephrase the follow-up \
question to be a standalone question.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone Question:"""


# Both settings are required; failing with a KeyError at import time makes a
# missing configuration obvious.
WEAVIATE_URL = os.environ["WEAVIATE_URL"]
WEAVIATE_API_KEY = os.environ["WEAVIATE_API_KEY"]


class ChatRequest(BaseModel):
    question: str
    # An explicit default keeps the field optional across pydantic versions.
    chat_history: Optional[List[Dict[str, str]]] = None
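# A hypothetical example payload, for illustration only; the "human"/"ai" keys
# match what serialize_history below expects:
# {
#     "question": "What is LCEL?",
#     "chat_history": [{"human": "Hi!", "ai": "Hello, how can I help?"}],
# }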


def get_embeddings_model() -> Embeddings:
    # Prefer Voyage embeddings when its endpoint and model are configured via
    # environment variables; otherwise fall back to OpenAI embeddings.
    if os.environ.get("VOYAGE_AI_URL") and os.environ.get("VOYAGE_AI_MODEL"):
        return VoyageEmbeddings()
    return OpenAIEmbeddings(chunk_size=200)


def get_retriever() -> BaseRetriever:
    weaviate_client = weaviate.Client(
        url=WEAVIATE_URL,
        auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY),
    )
    # Wrap the raw client in a LangChain vector store over the docs index.
    vectorstore = Weaviate(
        client=weaviate_client,
        index_name=WEAVIATE_DOCS_INDEX_NAME,
        text_key="text",
        embedding=get_embeddings_model(),
        by_text=False,
        attributes=["source", "title"],
    )
    return vectorstore.as_retriever(search_kwargs=dict(k=6))


def create_retriever_chain(
    llm: BaseLanguageModel, retriever: BaseRetriever
) -> Runnable:
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
    condense_question_chain = (
        CONDENSE_QUESTION_PROMPT | llm | StrOutputParser()
    ).with_config(
        run_name="CondenseQuestion",
    )
    conversation_chain = condense_question_chain | retriever
    # If the request carries chat history, condense the follow-up into a
    # standalone question before retrieving; otherwise retrieve on the raw
    # question directly.
    return RunnableBranch(
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),
            conversation_chain.with_config(run_name="RetrievalChainWithHistory"),
        ),
        (
            RunnableLambda(itemgetter("question")).with_config(
                run_name="Itemgetter:question"
            )
            | retriever
        ).with_config(run_name="RetrievalChainWithNoHistory"),
    ).with_config(run_name="RouteDependingOnChatHistory")


def format_docs(docs: Sequence[Document]) -> str:
    # Wrap each retrieved document in a numbered <doc> tag so the model can
    # cite results by index.
    formatted_docs = []
    for i, doc in enumerate(docs):
        doc_string = f"<doc id='{i}'>{doc.page_content}</doc>"
        formatted_docs.append(doc_string)
    return "\n".join(formatted_docs)
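# For example, two retrieved documents are serialized as:
#   <doc id='0'>first chunk</doc>
#   <doc id='1'>second chunk</doc>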


def serialize_history(request: ChatRequest):
    # Despite the annotation, the chain invokes this with a plain dict, so the
    # history is read with dict access; .get() tolerates a missing key.
    chat_history = request.get("chat_history") or []
    converted_chat_history = []
    for message in chat_history:
        if message.get("human") is not None:
            converted_chat_history.append(HumanMessage(content=message["human"]))
        if message.get("ai") is not None:
            converted_chat_history.append(AIMessage(content=message["ai"]))
    return converted_chat_history


def create_chain(
    llm: BaseLanguageModel,
    retriever: BaseRetriever,
) -> Runnable:
    retriever_chain = create_retriever_chain(
        llm,
        retriever,
    ).with_config(run_name="FindDocs")
    # Fan the request out into the formatted context plus the original
    # question and chat history.
    _context = RunnableMap(
        {
            "context": retriever_chain | format_docs,
            "question": itemgetter("question"),
            "chat_history": itemgetter("chat_history"),
        }
    ).with_config(run_name="RetrieveDocs")
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", RESPONSE_TEMPLATE),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )

    response_synthesizer = (prompt | llm | StrOutputParser()).with_config(
        run_name="GenerateResponse",
    )
    # Normalize the request, retrieve and format context, then synthesize the
    # final answer.
    return (
        {
            "question": RunnableLambda(itemgetter("question")).with_config(
                run_name="Itemgetter:question"
            ),
            "chat_history": RunnableLambda(serialize_history).with_config(
                run_name="SerializeHistory"
            ),
        }
        | _context
        | response_synthesizer
    )


llm = ChatOpenAI(
    model="gpt-3.5-turbo-16k",
    streaming=True,
    temperature=0,
)
retriever = get_retriever()
answer_chain = create_chain(
    llm,
    retriever,
)
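

# A minimal usage sketch, assuming WEAVIATE_URL, WEAVIATE_API_KEY, and OpenAI
# credentials are set; the question below is illustrative only.
if __name__ == "__main__":
    answer = answer_chain.invoke(
        {
            "question": "How do I compose runnables with LCEL?",
            "chat_history": [],
        }
    )
    print(answer)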