Move Chain to its own directory (#208)
* Move Chain

* Bump langserve

* Optionally include callback events (though currently may not work)
hinthornw authored Nov 1, 2023
1 parent 5295081 commit 0cb9d98
Showing 7 changed files with 738 additions and 775 deletions.
3 changes: 3 additions & 0 deletions chat_langchain_engine/__init__.py
@@ -0,0 +1,3 @@
from chat_langchain_engine.chain import answer_chain, ChatRequest

__all__ = ["answer_chain", "ChatRequest"]
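
For orientation, a minimal usage sketch of the new package's exports (not part of this diff): answer_chain is an LCEL runnable that takes a dict shaped like ChatRequest and returns a string. The question and history below are made-up placeholders, and WEAVIATE_URL, WEAVIATE_API_KEY, and OPENAI_API_KEY are assumed to be set, since chain.py builds its retriever and LLM at import time.

# Hypothetical caller, not part of the commit.
from chat_langchain_engine import answer_chain

answer = answer_chain.invoke(
    {
        "question": "How do I compose two runnables?",  # placeholder
        "chat_history": [
            # entries use the {"human": ..., "ai": ...} shape serialize_history expects
            {"human": "What is LCEL?", "ai": "LangChain Expression Language."}
        ],
    }
)
print(answer)  # a plain string, thanks to StrOutputParser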
196 changes: 196 additions & 0 deletions chat_langchain_engine/chain.py
@@ -0,0 +1,196 @@
"""Chat langchain 'engine'."""
import os
from operator import itemgetter
from typing import Dict, List, Optional, Sequence

import weaviate
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
    Runnable,
    RunnableBranch,
    RunnableLambda,
    RunnableMap,
)
from langchain.vectorstores import Weaviate
from pydantic import BaseModel

from chat_langchain_engine.constants import WEAVIATE_DOCS_INDEX_NAME
from voyage import VoyageEmbeddings

RESPONSE_TEMPLATE = """\
You are an expert programmer and problem-solver, tasked with answering any question \
about LangChain.
Generate a comprehensive and informative answer of 80 words or less for the \
given question based solely on the provided search results (URL and content). You must \
only use information from the provided search results. Use an unbiased and \
journalistic tone. Combine search results together into a coherent answer. Do not \
repeat text. Cite search results using [${{number}}] notation. Only cite the most \
relevant results that answer the question accurately. Place these citations at the end \
of the sentence or paragraph that references them - do not put them all at the end. If \
different results refer to different entities with the same name, write separate \
answers for each entity.
You should use bullet points in your answer for readability. Put citations where they \
apply rather than putting them all at the end.
If there is nothing in the context relevant to the question at hand, just say "Hmm, \
I'm not sure." Don't try to make up an answer.
Anything between the following `context` html blocks is retrieved from a knowledge \
bank, not part of the conversation with the user.
<context>
{context}
</context>
REMEMBER: If there is no relevant information within the context, just say "Hmm, I'm \
not sure." Don't try to make up an answer. Anything between the preceding 'context' \
html blocks is retrieved from a knowledge bank, not part of the conversation with the \
user.\
"""

REPHRASE_TEMPLATE = """\
Given the following conversation and a follow-up question, rephrase the follow-up \
question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone Question:"""


WEAVIATE_URL = os.environ["WEAVIATE_URL"]
WEAVIATE_API_KEY = os.environ["WEAVIATE_API_KEY"]


class ChatRequest(BaseModel):
    question: str
    chat_history: Optional[List[Dict[str, str]]]


def get_embeddings_model() -> Embeddings:
    if os.environ.get("VOYAGE_AI_URL") and os.environ.get("VOYAGE_AI_MODEL"):
        return VoyageEmbeddings()
    return OpenAIEmbeddings(chunk_size=200)


def get_retriever() -> BaseRetriever:
    weaviate_client = weaviate.Client(
        url=WEAVIATE_URL,
        auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY),
    )
    vectorstore = Weaviate(
        client=weaviate_client,
        index_name=WEAVIATE_DOCS_INDEX_NAME,
        text_key="text",
        embedding=get_embeddings_model(),
        by_text=False,
        attributes=["source", "title"],
    )
    return vectorstore.as_retriever(search_kwargs=dict(k=6))


def create_retriever_chain(
    llm: BaseLanguageModel, retriever: BaseRetriever
) -> Runnable:
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
    condense_question_chain = (
        CONDENSE_QUESTION_PROMPT | llm | StrOutputParser()
    ).with_config(
        run_name="CondenseQuestion",
    )
    conversation_chain = condense_question_chain | retriever
    return RunnableBranch(
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),
            conversation_chain.with_config(run_name="RetrievalChainWithHistory"),
        ),
        (
            RunnableLambda(itemgetter("question")).with_config(
                run_name="Itemgetter:question"
            )
            | retriever
        ).with_config(run_name="RetrievalChainWithNoHistory"),
    ).with_config(run_name="RouteDependingOnChatHistory")


def format_docs(docs: Sequence[Document]) -> str:
    formatted_docs = []
    for i, doc in enumerate(docs):
        doc_string = f"<doc id='{i}'>{doc.page_content}</doc>"
        formatted_docs.append(doc_string)
    return "\n".join(formatted_docs)


def serialize_history(request: Dict):
    # The chain passes a plain dict shaped like ChatRequest, so use key access;
    # .get() also tolerates a request that omits the optional chat_history.
    chat_history = request.get("chat_history") or []
    converted_chat_history = []
    for message in chat_history:
        if message.get("human") is not None:
            converted_chat_history.append(HumanMessage(content=message["human"]))
        if message.get("ai") is not None:
            converted_chat_history.append(AIMessage(content=message["ai"]))
    return converted_chat_history


def create_chain(
    llm: BaseLanguageModel,
    retriever: BaseRetriever,
) -> Runnable:
    retriever_chain = create_retriever_chain(
        llm,
        retriever,
    ).with_config(run_name="FindDocs")
    _context = RunnableMap(
        {
            "context": retriever_chain | format_docs,
            "question": itemgetter("question"),
            "chat_history": itemgetter("chat_history"),
        }
    ).with_config(run_name="RetrieveDocs")
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", RESPONSE_TEMPLATE),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )

    response_synthesizer = (prompt | llm | StrOutputParser()).with_config(
        run_name="GenerateResponse",
    )
    return (
        {
            "question": RunnableLambda(itemgetter("question")).with_config(
                run_name="Itemgetter:question"
            ),
            "chat_history": RunnableLambda(serialize_history).with_config(
                run_name="SerializeHistory"
            ),
        }
        | _context
        | response_synthesizer
    )


llm = ChatOpenAI(
    model="gpt-3.5-turbo-16k",
    streaming=True,
    temperature=0,
)
retriever = get_retriever()
answer_chain = create_chain(
    llm,
    retriever,
)
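
Given the "Bump langserve" note in the commit message, here is a hedged sketch of how the exported chain might be mounted behind an API; the app module, route path, and port below are assumptions, not shown anywhere in this diff.

# Hypothetical serving layer, assuming langserve and FastAPI are installed.
from fastapi import FastAPI
from langserve import add_routes

from chat_langchain_engine import answer_chain

app = FastAPI()
add_routes(app, answer_chain, path="/chat")  # route path is an assumption

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)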
File renamed without changes.
2 changes: 1 addition & 1 deletion ingest.py
@@ -14,7 +14,7 @@
 from langchain.utils.html import PREFIXES_TO_IGNORE_REGEX, SUFFIXES_TO_IGNORE_REGEX
 from langchain.vectorstores import Weaviate
 
-from constants import WEAVIATE_DOCS_INDEX_NAME
+from chat_langchain_engine.constants import WEAVIATE_DOCS_INDEX_NAME
 from voyage import VoyageEmbeddings
 
 logger = logging.getLogger(__name__)
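
The updated import points at chat_langchain_engine/constants.py, whose contents are not shown in this diff. A plausible reconstruction follows, purely for orientation; the actual index name is not visible anywhere in this commit.

# chat_langchain_engine/constants.py (hypothetical reconstruction)
WEAVIATE_DOCS_INDEX_NAME = "LangChain_docs_index"  # placeholder value, not from the repo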