-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: StreamlitConversation.py
79 lines (64 loc) · 2.67 KB
/
StreamlitConversation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# https://streamlit.io/
# https://python.langchain.com/en/latest/modules/memory/types/buffer_window.html
from langchain.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate
)
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferWindowMemory
import streamlit as st
from streamlit_chat import message
from dotenv import load_dotenv
import os
# System prompt sent to the model ahead of every turn; also echoed on the page.
system_message: str = "You are a helpful assistant."
# Window size handed to ConversationBufferWindowMemory and shown in the header.
k: int = 2
# st.session_state slot that holds this browser session's ConversationChain.
conversation_key: str = "conversation"
# ``msg.type`` value that marks a user-authored message in chat memory.
human_message_key: str = "human"
def get_api_key():
    """Render a masked input box for the OpenAI API key and return its text.

    The widget is registered under the Streamlit key ``openai_api_key_input``
    and hides what is typed (``type="password"``).

    Returns:
        Whatever ``st.text_input`` reports for the widget on this run.
    """
    widget_options = {
        "label": "OpenAI API Key ",
        "type": "password",
        "placeholder": "Ex: sk-2twm...",
        "key": "openai_api_key_input",
    }
    return st.text_input(**widget_options)
def getConversation():
    """Build a new ConversationChain backed by a sliding-window chat memory.

    The prompt is made of three parts, in order:
    the module-level ``system_message`` as the system prompt, the remembered
    messages injected at the ``history`` placeholder, and the user's new text
    via the ``{input}`` template.

    Returns:
        A ``ConversationChain`` that uses ``ChatOpenAI(temperature=0)`` as its
        model and ``ConversationBufferWindowMemory(k=k, return_messages=True)``
        as its memory, with ``verbose=True``.
    """
    chat_template = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_message),
            # Must match the memory's variable name ("history") so the chain
            # can feed remembered turns into the prompt.
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template("{input}"),
        ]
    )
    # return_messages=True yields message objects, which is what a
    # MessagesPlaceholder expects, rather than one flattened string.
    window_memory = ConversationBufferWindowMemory(k=k, return_messages=True)
    chat_model = ChatOpenAI(temperature=0)
    return ConversationChain(
        llm=chat_model,
        prompt=chat_template,
        memory=window_memory,
        verbose=True,
    )
def submit():
    """Send the pending chat text to this session's conversation chain.

    Used as the ``on_change`` callback of the ``user_input`` text box. It
    reads the submitted text and clears the box. If the text contains
    anything other than whitespace, it passes the original text to the
    stored ``ConversationChain``, whose memory then holds the exchange.
    """
    user_input = st.session_state.user_input
    # Reset the widget value so the box is empty when the script reruns.
    st.session_state.user_input = ''
    # Any non-blank text is sent, including single-character messages such
    # as "?" or "y". Empty and whitespace-only input is ignored.
    if user_input and user_input.strip():
        conversation = st.session_state[conversation_key]
        # The returned reply is not used here. main() redraws the history
        # from the chain's memory on the next run.
        conversation.predict(input=user_input)
def _ensure_openai_api_key():
    """Make sure OPENAI_API_KEY holds a non-empty value, prompting if needed.

    Returns:
        True if the environment variable is (now) set to a non-empty key,
        False if the user has not provided one yet.
    """
    if os.getenv("OPENAI_API_KEY"):
        return True
    print("OPENAI_API_KEY is not set")
    api_key = get_api_key()
    if not api_key:
        # Don't store an empty string. An empty value would count as
        # "present" for load_dotenv(), which by default does not override
        # existing variables.
        return False
    # NOTE(review): os.environ is process-wide. On a shared Streamlit server,
    # a key typed by one visitor is used by every session. Passing the key
    # straight to ChatOpenAI would keep it per-session; confirm before deploying.
    os.environ["OPENAI_API_KEY"] = api_key
    return True


def _render_history(conversation):
    """Draw every message stored in the chain's chat memory as a chat bubble."""
    for index, msg in enumerate(conversation.memory.chat_memory.messages):
        # Messages whose type is human_message_key are shown as user bubbles.
        message(msg.content, is_user=(msg.type == human_message_key), key=f"msg{index}")


def main():
    """Streamlit entry point for the windowed-memory chat demo.

    Steps:
    - Draw the page header.
    - Load ``.env`` and make sure an OpenAI API key is available; stop early
      if none is.
    - Reuse or create this session's ``ConversationChain``.
    - Show the stored history and the input box, whose ``on_change``
      callback is ``submit``.
    """
    st.set_page_config(page_title="Conversation Buffer Window Memory", page_icon=":robot:")
    st.title("Conversation")
    st.markdown(f"System Message: {system_message}")
    st.header(f"Buffer Window Memory k={k}")

    load_dotenv()
    if not _ensure_openai_api_key():
        return

    placeholder = st.empty()
    # st.session_state persists across reruns of the same session, so the
    # chain and its memory are built only once per session.
    if conversation_key not in st.session_state:
        st.session_state[conversation_key] = getConversation()
    conversation = st.session_state[conversation_key]

    with placeholder.container():
        _render_history(conversation)
    st.text_input(label="Enter your message", placeholder="Send a message", key="user_input", on_change=submit)
# Script entry: run the app only when this file is executed, not on import.
if __name__ == '__main__':
    main()