Commit
Merge pull request #77 from Renumics/new-api
New API version with Todos
SYoy authored Mar 3, 2025
2 parents 0595079 + e94a6a5 commit 24a67ad
Showing 61 changed files with 16,033 additions and 14,430 deletions.
13 changes: 13 additions & 0 deletions TODOS.md
@@ -0,0 +1,13 @@
# Open tasks before release of new API version

- [ ] Style new conversion button in Chatwindow
- [ ] Docs: RagProvider / state management
- [ ] Docs: Hooks
- [ ] Docs: Update to API 2.0 (new features, user-action flow)
- [ ] Python API: Update to API 2.0
- [ ] (Optional) add user action modifiers to current set of operations (slack canvas)
- [ ] activeSources -> [], null, or set of retrievedSources
- [ ] Document .data attribute and how to omit it from backend calls to avoid timeouts
- [ ] Document Hooks (docstrings)
- [ ] Look into type inference (auto complete) for ActionHandlerResponses
- [ ] Refactor docs 'LexioProvider'
252 changes: 95 additions & 157 deletions examples/advanced-local-rag/backend/main.py
@@ -28,9 +28,14 @@
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Load model & tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct"
model_name = "Qwen/Qwen2.5-7B-Instruct" if device == "cuda" else "HuggingFaceTB/SmolLM2-360M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True).to(device)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
load_in_4bit=True if device == "cuda" else False,
bnb_4bit_compute_dtype=torch.float16 # Set compute dtype to float16
).to(device)
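
For context, newer transformers releases route 4-bit loading through BitsAndBytesConfig rather than a bare load_in_4bit flag, and quantized models generally cannot be moved with .to(device); a minimal sketch of the equivalent call, assuming bitsandbytes is installed:

```python
# Sketch only, not part of this commit: equivalent 4-bit load via BitsAndBytesConfig.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matches the compute dtype set above
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    device_map="auto",  # places shards automatically; no .to(device) needed
    quantization_config=quant_config,
)
```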

class Message(BaseModel):
role: str
@@ -74,7 +79,10 @@ def run_generation():
inputs=inputs,
streamer=streamer,
max_new_tokens=1024,
do_sample=False # set to True if you'd like more "interactive" sampling
do_sample=False,
top_p=None,
top_k=None,
temperature=None
)
finally:
# Ensure streamer is properly ended even if generation fails
@@ -87,42 +95,90 @@ def run_generation():
# 5) Return the streamer to the caller
return streamer
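
The hunks only show fragments of generate_stream; a hedged sketch of the standard TextIteratorStreamer pattern it appears to implement (the prompt handling below is illustrative, not from this commit):

```python
# Illustrative sketch, not from this commit: generation runs on a background
# thread and feeds a TextIteratorStreamer, which the caller iterates over.
from threading import Thread
from transformers import TextIteratorStreamer

prompt = "Hello"  # hypothetical input for illustration
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs=dict(inputs=input_ids, streamer=streamer, max_new_tokens=1024, do_sample=False),
    daemon=True,
).start()

for text_piece in streamer:  # yields decoded text fragments as they are generated
    print(text_piece, end="", flush=True)
```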

class RetrieveAndGenerateRequest(BaseModel):
query: str

@app.post("/api/retrieve-and-generate")
async def retrieve_and_generate(request: RetrieveAndGenerateRequest):
@app.get("/pdfs/{id}")
async def get_pdf(id: str):
"""
SSE endpoint that:
1) Retrieves relevant chunks from your DB
2) Yields a JSON with 'sources' first
3) Streams tokens from the model as they are generated (real time)
Endpoint to serve document content by looking up the ID in the database.
Handles both binary (PDF) and text-based (HTML, Markdown) content.
"""
query = request.query
# First look up the document path using the ID from the database
table = db_utils.get_table()
results = table.search().where(f"id = '{id}'", prefilter=True).to_list()

if not results:
raise HTTPException(status_code=404, detail="Document ID not found")

# Get the document path from the first result
doc_path = results[0]['doc_path']

if not os.path.exists(doc_path):
raise HTTPException(status_code=404, detail="File not found on disk")

# Determine content type based on file extension
if doc_path.endswith('.pdf'):
return FileResponse(
doc_path,
media_type='application/pdf',
filename=os.path.basename(doc_path)
)
elif doc_path.endswith('.html'):
return FileResponse(
doc_path,
media_type='text/html',
filename=os.path.basename(doc_path)
)
else: # Markdown or other text files
return FileResponse(
doc_path,
media_type='text/plain',
filename=os.path.basename(doc_path)
)
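
A hedged client-side sketch of calling this endpoint (the host, port, and document id are illustrative, not from this commit):

```python
# Usage sketch; "doc-123" is a hypothetical id.
import requests

resp = requests.get("http://localhost:8000/pdfs/doc-123")
resp.raise_for_status()
print(resp.headers["content-type"])  # application/pdf, text/html, or text/plain
with open("document.out", "wb") as f:
    f.write(resp.content)
```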

class ChatRequest(BaseModel):
messages: List[Message]
source_ids: Optional[List[str]] = None

@app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
"""
Unified SSE endpoint that handles both initial queries and follow-ups:
- messages: list of chat messages (required)
- source_ids: optional list of specific source IDs to use
If source_ids are provided, those specific sources will be used as context
If no source_ids are provided, the system will automatically retrieve relevant sources
based on the latest user query
"""
print("Request received:", request)
try:
# Start timing
start_time = time.time()
messages_list = request.messages
print(f"Chat history length: {len(messages_list)}")
print("Message roles:", [msg.role for msg in messages_list])

context_str = ""
sources = []

# 1) Time the embedding generation
embed_start = time.time()
query_embedding = db_utils.get_model().encode(query)
embed_time = time.time() - embed_start
print(f"Embedding generation took: {embed_time:.2f} seconds")
# Get the latest user message as the query for retrieval
latest_query = next((msg.content for msg in reversed(messages_list) if msg.role == "user"), None)

# 2) Time the database search
search_start = time.time()
table = db_utils.get_table()
results = (
table.search(query=query_embedding, vector_column_name="embedding")
.limit(5)
.to_list()
)

search_time = time.time() - search_start
print(f"Database search took: {search_time:.2f} seconds")

# 3) Time the sources processing
process_start = time.time()
if request.source_ids:
# Use specified sources if provided
print(f"Using provided source IDs: {request.source_ids}")
source_ids_str = "('" + "','".join(request.source_ids) + "')"
results = table.search().where(f"id in {source_ids_str}", prefilter=True).to_list()
else:
# Otherwise perform semantic search based on the latest query
print(f"Performing semantic search for: {latest_query}")
query_embedding = db_utils.get_model().encode(latest_query)
results = (
table.search(query=query_embedding, vector_column_name="embedding")
.limit(5)
.to_list()
)

# Process results into sources and context
sources = [
{
"doc_path": r["doc_path"],
@@ -142,160 +198,42 @@ async def retrieve_and_generate(request: RetrieveAndGenerateRequest):
}
for r in results
]
process_time = time.time() - process_start
print(f"Processing results took: {process_time:.2f} seconds")

# Log total preparation time
total_prep_time = time.time() - start_time
print(f"Total preparation time: {total_prep_time:.2f} seconds")

# 4) Build context

context_str = "\n\n".join([
f"[Document: {r['doc_path']}]\n{r['text']}"
for r in results
])
messages = [Message(role="user", content=query)]

# 5) Create async generator to yield SSE
async def event_generator():
try:
# First yield the sources
yield {"data": json.dumps({"sources": sources})}

# Now create the streamer & generate tokens
streamer = generate_stream(messages, context_str)

# For each partial token, yield SSE data
for token in streamer:
if token: # Only send if token is not empty
try:
data = json.dumps({"content": token, "done": False})
yield {"data": data}
await asyncio.sleep(0) # let the event loop flush data
except Exception as e:
print(f"Error during token streaming: {str(e)}")
continue

# Finally, yield "done"
yield {"data": json.dumps({"content": "", "done": True})}
except Exception as e:
print(f"Error in event generator: {str(e)}")
yield {"data": json.dumps({"error": str(e)})}

# 6) Return SSE
return EventSourceResponse(event_generator())

except Exception as e:
return {"error": str(e)}

class GenerateRequest(BaseModel):
messages: List[Message]
source_ids: Optional[List[str]] = None

@app.post("/api/generate")
async def generate_endpoint(request: GenerateRequest):
"""
SSE endpoint for follow-up requests using:
- messages: list of previous chat messages
- source_ids: optional list of doc references to build context
Streams the model response in real time.
"""
try:
# Log message history length and content
messages_list = request.messages
print(f"Chat history length: {len(messages_list)}")
print("Message roles:", [msg.role for msg in messages_list])

# Log source usage
source_ids_list = request.source_ids
print(f"Using source IDs: {source_ids_list if source_ids_list else 'No sources'}")

# 1) Build context from source IDs (if provided)
context_str = ""
if source_ids_list:
table = db_utils.get_table()
source_ids_str = "('" + "','".join(source_ids_list) + "')"
chunks = table.search().where(f"id in {source_ids_str}", prefilter=True).to_list()

# Log retrieved chunks info
print(f"Retrieved {len(chunks)} chunks from database")
for chunk in chunks:
print(f"Document: {chunk['doc_path']}")

context_str = "\n\n".join([
f"[Document: {chunk['doc_path']}]\n{chunk['text']}"
for chunk in chunks
])
print(f"Total context length: {len(context_str)} characters")

# 2) Build async generator for SSE
async def event_generator():
try:
# Create the streamer
# First yield the sources if we have any
if sources:
yield {"data": json.dumps({"sources": sources})}

# Create the streamer & generate tokens
streamer = generate_stream(messages_list, context_str)

# For each partial token, yield SSE data
for token in streamer:
if token: # Only send if token is not empty
if token:
try:
data = json.dumps({"content": token, "done": False})
yield {"data": data}
await asyncio.sleep(0) # yield control so data can flush
await asyncio.sleep(0)
except Exception as e:
print(f"Error during token streaming: {str(e)}")
continue

# Finally, yield "done"
yield {"data": json.dumps({"content": "", "done": True})}
except Exception as e:
print(f"Error in event generator: {str(e)}")
yield {"data": json.dumps({"error": str(e)})}

# 3) Return SSE
return EventSourceResponse(event_generator())

except Exception as e:
return {"error": str(e)}

@app.get("/pdfs/{id}")
async def get_pdf(id: str):
"""
Endpoint to serve document content by looking up the ID in the database.
Handles both binary (PDF) and text-based (HTML, Markdown) content.
"""
# First look up the document path using the ID from the database
table = db_utils.get_table()
results = table.search().where(f"id = '{id}'", prefilter=True).to_list()

if not results:
raise HTTPException(status_code=404, detail="Document ID not found")

# Get the document path from the first result
doc_path = results[0]['doc_path']

if not os.path.exists(doc_path):
raise HTTPException(status_code=404, detail="File not found on disk")

# Determine content type based on file extension
if doc_path.endswith('.pdf'):
return FileResponse(
doc_path,
media_type='application/pdf',
filename=os.path.basename(doc_path)
)
elif doc_path.endswith('.html'):
return FileResponse(
doc_path,
media_type='text/html',
filename=os.path.basename(doc_path)
)
else: # Markdown or other text files
return FileResponse(
doc_path,
media_type='text/plain',
filename=os.path.basename(doc_path)
)

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
