Commit 22dad76

Initial commit

new file: .gitignore
new file: README.md
new file: app.py
new file: datavectoriser.py
File tree: 4 files changed (+254, -0)

.gitignore

+3
@@ -0,0 +1,3 @@
.venv/*
__pycache__/*
myapp.py

README.md

Whitespace-only changes.

app.py

+199
@@ -0,0 +1,199 @@
import os
import shutil
from dotenv import load_dotenv, find_dotenv
from pathlib import Path
from langchain.document_loaders.text import TextLoader
from langchain.document_loaders.directory import DirectoryLoader
from langchain.prompts import PromptTemplate
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain.vectorstores.lancedb import LanceDB
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain import hub

from io import BytesIO
import PyPDF2  # For PDF handling
from docx import Document
import streamlit as st
import fitz
import validators
import lancedb  # used directly in qa_bot
import warnings
from datavectoriser import *

warnings.filterwarnings("ignore")

load_dotenv(find_dotenv(r'.venv\.env'))  # read local .env file

vectorstore_path = "vectorstore/db_lancedb"  # Path where the embeddings are persisted

custom_prompt_template = """Use the following information to answer the user's questions. If you don't know the answer, just say "I don't know the answer". DO NOT make up answers that are not based on facts. Explain with detailed answers that are easy to understand. You are free to draw inferences based on the information provided in the context in order to answer the questions as best as possible.

Context: {context}
Question: {question}

Only return the useful aspects of the answer below and nothing else.
Helpful answer:
"""

def set_custom_prompt():
    """
    Prompt template for QA retrieval for each vector store; we also pass in context and question.
    """
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
    return prompt

def load_llm():
    """
    Load the Gemini Pro chat model via langchain_google_genai.
    """
    llm = ChatGoogleGenerativeAI(
        model="gemini-pro",
        max_output_tokens=512,
        temperature=0.5,
        convert_system_message_to_human=True
    )
    return llm

def retrieval_qa_chain(llm, prompt, db):
    """
    Set up a retrieval-based question-answering chain and return it.
    Note: the chain currently uses the hub prompt; the custom `prompt`
    argument is accepted but not wired in.
    """
    retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")

    combine_docs_chain = create_stuff_documents_chain(
        llm, retrieval_qa_chat_prompt
    )

    retriever = db.as_retriever(search_kwargs={'k': 2})

    retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)

    return retrieval_chain

def qa_bot():
    """
    Load the vector store and use it in retrieval_qa_chain.
    """
    embeddings = GoogleGenerativeAIEmbeddings(model='models/embedding-001')

    db = lancedb.connect(vectorstore_path)
    table = db.open_table("vector_table")
    vectorstore_db = LanceDB(table, embedding=embeddings)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, vectorstore_db)
    return qa

@st.cache_resource(show_spinner=False)
def final_result(_qa_result, query):
    response = _qa_result.invoke({'input': query})
    return response

def handle_uploads(uploaded_file):
    file_extension = uploaded_file.name.split(".")[-1].lower()
    try:
        if file_extension == "pdf":
            file_bytes = uploaded_file.read()

            # Use BytesIO to simulate a file-like object
            pdf_reader = PyPDF2.PdfReader(BytesIO(file_bytes))
            # pdf_reader = fitz

            # Extract text from all pages
            text = ""
            for page in pdf_reader.pages:
                text += page.extract_text() + "\n"

            return text
        elif file_extension == "docx":
            file_bytes = uploaded_file.read()

            # Use BytesIO to simulate a file-like object
            document = Document(BytesIO(file_bytes))

            # Extract text from all paragraphs
            text = ""
            for paragraph in document.paragraphs:
                text += paragraph.text + "\n"

            return text
    except Exception:
        return False

@st.cache_data(show_spinner=False)
def process_uploaded_files(uploaded_files, data_store):
    for file in uploaded_files:
        allowed_extensions = {"pdf", "docx"}
        file_extension = file.name.split(".")[-1].lower()

        if file_extension not in allowed_extensions:
            print(f"{file.name}: Unsupported file type. Please upload a PDF or DOCX file.")
            continue

        new_file = file.name.split('.')[0] + '.txt'
        text = handle_uploads(file)
        if text:
            with open(os.path.join(data_store, new_file), "w", encoding='utf-8-sig', errors='ignore') as f:
                f.write(text)
        else:
            print("Cannot process the document.")

    flag = create_vector_db(data_store, vectorstore_path)

    # create_vector_db returns True on success, or the raised exception on
    # failure; an exception object is truthy, so compare against True explicitly.
    if flag is True:
        return True
    else:
        st.error(f"Error in vectorising the data: {flag}")
        return False


def main():
    data_store = "data"
    os.makedirs(data_store, exist_ok=True)

    # Function to clear the uploaded file data
    def clear_uploaded_file():
        st.session_state.uploaded_file = None

    st.title("QA Chatbot")

    # Initialize session state
    if 'uploaded_file' not in st.session_state:
        st.session_state.uploaded_file = None

    uploaded_files = st.file_uploader("Upload PDF or Document files", accept_multiple_files=True)

    flag = False
    if uploaded_files:
        with st.spinner("Processing uploaded files..."):
            flag = process_uploaded_files(uploaded_files, data_store)

    if flag:
        st.subheader("Chat Session")
        chain = qa_bot()  # Initialize QA chain

        query = st.chat_input("Ask your question...", key="query_chat_input")
        user = st.chat_message("user")
        ai = st.chat_message("ai")

        if query:
            user.write(query)
            answer = final_result(chain, query)
            ai.write(answer["answer"])

    if st.button('Reset'):
        st.session_state.clear()
        shutil.rmtree(data_store)
        shutil.rmtree(vectorstore_path)
        clear_uploaded_file()
        uploaded_files = None
        flag = False
        st.rerun()

    if st.button('Stop'):
        st.stop()

if __name__ == "__main__":
    main()
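For context (not part of the commit): create_retrieval_chain returns a runnable whose .invoke takes a dict with an "input" key and yields a dict containing "input", "context" (the retrieved documents) and "answer". A minimal sketch of driving the chain outside Streamlit, assuming the vector store has already been built and GOOGLE_API_KEY is reachable by load_dotenv (the file name sketch.py is hypothetical):

# sketch.py -- hypothetical standalone use of the chain wired up in app.py
from dotenv import load_dotenv, find_dotenv
from app import qa_bot

load_dotenv(find_dotenv())  # assumes GOOGLE_API_KEY is defined in a local .env

chain = qa_bot()  # opens the LanceDB table and builds the retrieval chain
response = chain.invoke({'input': 'What are these documents about?'})

print(response['answer'])        # the model's reply
for doc in response['context']:  # the k=2 chunks the retriever returned
    print(doc.page_content[:80])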

datavectoriser.py

+52
@@ -0,0 +1,52 @@
import shutil
import os
import lancedb
import pyarrow as pa
# from langchain_experimental.text_splitter import SemanticChunker
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from langchain.document_loaders.text import TextLoader
from langchain.document_loaders.directory import DirectoryLoader
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain.vectorstores.lancedb import LanceDB

# Create a vector database
def create_vector_db(data_path, vectorstore_path):
    try:
        # Overwrite any existing vector store with the new data
        if os.path.exists(vectorstore_path):
            shutil.rmtree(vectorstore_path)

        # Load the extracted .txt files as documents using TextLoader
        loader = DirectoryLoader(data_path, glob="*.txt", loader_cls=TextLoader, silent_errors=True, loader_kwargs={'encoding': 'utf-8-sig'}, use_multithreading=True)
        documents = loader.load()

        # Create text embeddings; numerical vectors that represent the semantics of the text
        embedder = GoogleGenerativeAIEmbeddings(model='models/embedding-001')

        # Split documents into smaller segments
        # text_splitter = SemanticChunker(embeddings=embedder)
        text_splitter = SentenceTransformersTokenTextSplitter()
        docs = text_splitter.split_documents(documents=documents)

        db = lancedb.connect(vectorstore_path)

        # Define schema (embedding-001 returns 768-dimensional vectors)
        schema = pa.schema([
            pa.field("vector", pa.list_(pa.float32(), 768)),
            pa.field("text", pa.utf8()),
            pa.field("id", pa.utf8()),
        ])

        # Create table
        table = db.create_table("vector_table", schema=schema, mode="overwrite")

        vectorstore = LanceDB(table, embedder)

        vectorstore.add_documents(docs)

        return True
    except Exception as e:
        return e

if __name__ == "__main__":
    # Paths mirror the constants used in app.py
    create_vector_db("data", "vectorstore/db_lancedb")
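For illustration (a sketch, not part of the commit): the LanceDB table written above can be opened directly to sanity-check the ingest. The path and table name mirror the constants in app.py, and the schema fields come from the code above; the file name inspect_store.py is hypothetical.

# inspect_store.py -- hypothetical sanity check of the persisted vector store
import lancedb

db = lancedb.connect("vectorstore/db_lancedb")  # same path app.py reads
table = db.open_table("vector_table")

print(table.schema)        # vector: list<float32, 768>, text, id
print(table.count_rows())  # number of embedded chunks stored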
