-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrag.py
More file actions
186 lines (157 loc) · 5.75 KB
/
rag.py
File metadata and controls
186 lines (157 loc) · 5.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
from dotenv import load_dotenv
import chromadb
import tiktoken
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
# ----------------------------
# Config
# ----------------------------
# Directory where the ChromaDB data is stored on disk.
PERSIST_DIR: str = "chroma_db"
# Name of the collection that is looked up in that store.
COLLECTION_NAME: str = "documents"
# Embedding model name; use the same one that was used when indexing.
EMBED_MODEL: str = "all-MiniLM-L6-v2"
# Model name passed to the Gemini client.
GENAI_MODEL: str = "gemini-2.0-flash"
# Token budget for the retrieved context; adjust based on the Gemini model limit.
TOKEN_LIMIT: int = 3000
# ----------------------------
# Init Embedding + ChromaDB
# ----------------------------
# Load the embedding model once at module load; it is reused for every query.
print("Loading embedding model....")
embedder = SentenceTransformer(EMBED_MODEL)

# Open the persistent ChromaDB store located at PERSIST_DIR.
print("Connecting to chromadb....")
client = chromadb.PersistentClient(path=PERSIST_DIR)
try:
    collection = client.get_collection(COLLECTION_NAME)
except Exception:
    # Create the collection if the lookup failed, presumably because it does
    # not exist yet (the exact exception class raised by get_collection
    # appears to differ between chromadb releases -- TODO confirm and narrow).
    # `except Exception` is used instead of a bare `except` so that
    # KeyboardInterrupt and SystemExit still propagate.
    collection = client.create_collection(COLLECTION_NAME)
    print(f"Created new collection: {COLLECTION_NAME}")
# ----------------------------
# Init Gemini LLM
# ----------------------------
print("Configuring gemini api....")
# Load variables from a .env file (if one is present) before reading the key.
load_dotenv()
api_key = os.environ.get("GEMINI_API_KEY")
if api_key is None or api_key == "":
    # Missing or empty key: abort module load.
    raise ValueError("Gemini api key not found")
genai.configure(api_key=api_key)
# Module-level model handle used for every generation call.
llm = genai.GenerativeModel(GENAI_MODEL)
# ----------------------------
# Token counter (for trimming context)
# ----------------------------
# NOTE(review): cl100k_base is not Gemini's own tokenizer, so counts made with
# it are presumably only an approximation of the model's limit -- confirm that
# this is acceptable for TOKEN_LIMIT.
tokenizer = tiktoken.get_encoding("cl100k_base")


def trim_to_budget(text, limit=TOKEN_LIMIT):
    """Return ``text`` shortened to at most ``limit`` tokens.

    Args:
        text: The string to measure and, if necessary, truncate.
        limit: Maximum number of ``tokenizer`` tokens to keep. Values below
            zero are treated as zero.

    Returns:
        ``text`` itself when it already fits within ``limit`` tokens,
        otherwise the decoded first ``limit`` tokens.
    """
    # The text comes from arbitrary documents. With tiktoken's default
    # settings, a literal special-token marker such as "<|endoftext|>" in the
    # input makes encode() raise ValueError; disallowed_special=() encodes
    # such markers as ordinary text instead.
    tokens = tokenizer.encode(text, disallowed_special=())
    if len(tokens) <= limit:
        return text
    # Clamp so a negative limit cannot turn the slice into "drop the last n".
    return tokenizer.decode(tokens[:max(limit, 0)])
# ----------------------------
# RAG Pipeline
# ----------------------------
def rag_pipeline(user_question: str, top_k: int = 5):
    """Answer ``user_question`` from the indexed documents using retrieval + Gemini.

    The question is embedded, the nearest chunks are fetched from the ChromaDB
    collection, joined and trimmed to ``TOKEN_LIMIT`` tokens, and sent to the
    LLM inside a fixed instruction prompt. Progress is printed for each step.

    Args:
        user_question: The question to answer.
        top_k: Maximum number of chunks to retrieve.

    Returns:
        A ``(answer, citations)`` tuple. ``answer`` is the model's text, or a
        fixed notice when the collection is empty or cannot be counted.
        ``citations`` is a list of dicts with the keys ``file_path``,
        ``chunk_index`` and ``distance``, one per retrieved chunk; it is empty
        for the notices above and when the model replied that the answer is
        not present in the file.
    """
    # Check that the collection has documents before doing any work.
    try:
        count = collection.count()
    except Exception:
        # Narrowed from a bare `except` so KeyboardInterrupt/SystemExit are
        # no longer swallowed here.
        return "Database not initialized. Please upload a document first.", []
    if count == 0:
        return "No documents have been uploaded yet. Please upload a document first.", []

    # Step 1: Embed the query
    print("Step1: Embedding user question...")
    query_embedding = embedder.encode(user_question).tolist()
    print("Step1 complete.")

    # Step 2: Query ChromaDB
    print("Step2: Querying chromadb...")
    results = collection.query(
        query_embeddings=[query_embedding],
        # Never request more neighbours than are stored; how an oversized
        # n_results is handled presumably depends on the chromadb release
        # (TODO confirm), so avoid relying on it.
        n_results=min(top_k, count),
    )
    print("Step2 complete.")

    # Step 3: Assemble context
    print("Step3: Assembling context....")
    chunks = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]
    # Guard against entries that come back without a document text
    # (str.join would raise TypeError on None).
    texts = [chunk for chunk in chunks if chunk is not None]
    context = trim_to_budget("\n\n".join(texts), TOKEN_LIMIT)
    print("Step3 complete.")

    # Step 4: Build prompt
    # The prompt text is kept flush-left so that no source indentation leaks
    # into the string sent to the model. The chunk count is interpolated
    # instead of a hard-coded "5" so it stays true for any top_k.
    print("Step4: Building prompt")
    prompt = f"""
You are a factual RAG assistant.
You will receive:
1.A user question .
2.The top {len(texts)} most relevant chunks from the reference file.
Instructions:
1.Carefully read all the provided chunks before answering.
2.Identify the one chunk that best answers the question:
-Use only that chunk's content as the core of your answer.
-You may rephrase or enrich the answer **so it is clear,coherent,and medium in length**(not too short ,not too long).
-Provide a answer that **explains the main points** and is **not too long but should main points**.
-Do not give a very short or vague answer.
-Make sure the answer is **directly related to the question**.
-If none of the chunks directly answer the question,reply excatly:
"The answer is not present in the provided file."
Context:
{context}
Question:
{user_question}
"""
    print("Step4 complete.")

    # Step 5: Call Gemini
    print("Step5: Calling gemini llm....")
    response = llm.generate_content(prompt)
    print("Step5 complete")

    # Step 6: Attach citations (only if an answer was found)
    print("Step6: Attaching citations...")
    # Read the text once and reuse it for both the check and the return value.
    answer = response.text
    if "The answer is not present in the provided file." in answer:
        # No citations for "not found" responses.
        citations = []
    else:
        citations = [
            {
                # `meta or {}` tolerates a None metadata entry (presumably a
                # chunk that was indexed without metadata).
                "file_path": (meta or {}).get("file_path", "unknown"),
                "chunk_index": (meta or {}).get("chunk_index", i),
                "distance": distance,
            }
            for i, (meta, distance) in enumerate(zip(metadatas, distances))
        ]
    print("Step6 complete.")
    return answer, citations
# ----------------------------
# Run Example
# ----------------------------
def main():
    """Run the interactive console loop: ask, answer, list citations, repeat.

    The loop keeps going while the user answers "yes" or "y" to the continue
    prompt, and stops early when the question itself is "exit" or "quit".
    A closed stdin (EOF) or Ctrl-C ends the session with the normal goodbye
    message instead of a traceback.
    """
    print("\nWelcome to RAG Phase2 with Gemini!")
    try:
        # 1. Ask first if they want to start
        continue_asking = input("Do you want to ask a question? (yes/no): ").lower().strip()
        # 2. Loop until user says no
        while continue_asking in ("yes", "y"):
            print("\n---------------------------------------------------")
            # Get the question
            question = input("Enter your question (or type 'exit' to quit): ").strip()
            # If user types 'exit' or 'quit', break out of the loop
            if question.lower() in ("exit", "quit"):
                break
            # Run pipeline
            answer, cites = rag_pipeline(question)
            # Show answer
            print("\n--- Answer ---")
            print(answer)
            print("\n--- Citations ---")
            for c in cites:
                print(c)
            # Ask again
            continue_asking = input("\nDo you want to ask another question? (yes/no): ").lower().strip()
    except (EOFError, KeyboardInterrupt):
        # stdin was closed or the user pressed Ctrl-C: end the current line
        # and fall through to the goodbye message.
        print()
    print("\nExiting RAG assistant. Goodbye! ")


if __name__ == "__main__":
    main()