LlamaIndex Workflows with routing and multisource RAG (#222)
* Multisource + routing initial version

* Running black

* Removing unnecessary prompt details

* Updating starters

* Remove unnecessary styling
ckrapu-nv authored Oct 26, 2024
1 parent b2482e1 commit cd31614
Showing 21 changed files with 946 additions and 0 deletions.
121 changes: 121 additions & 0 deletions community/routing-multisource-rag/.chainlit/config.toml
@@ -0,0 +1,121 @@
[project]
# Whether to enable telemetry (default: true). No personal data is collected.
enable_telemetry = true


# List of environment variables to be provided by each user to use the app.
user_env = []

# Duration (in seconds) during which the session is saved when the connection is lost
session_timeout = 3600

# Enable third-party caching (e.g. LangChain cache)
cache = false

# Authorized origins
allow_origins = ["*"]

# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
# follow_symlink = false

[features]
# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
unsafe_allow_html = false

# Process and display mathematical expressions. This can clash with "$" characters in messages.
latex = false

# Automatically tag threads with the current chat profile (if a chat profile is used)
auto_tag_thread = true

# Allow users to edit their own messages
edit_message = true

# Authorize users to spontaneously upload files with messages
[features.spontaneous_file_upload]
enabled = true
accept = ["*/*"]
max_files = 20
max_size_mb = 500

[features.audio]
# Threshold for audio recording
min_decibels = -45
# Delay for the user to start speaking in MS
initial_silence_timeout = 3000
# Delay for the user to continue speaking in MS. If the user stops speaking for this duration, the recording will stop.
silence_timeout = 1500
# Above this duration (MS), the recording will forcefully stop.
max_duration = 15000
# Duration of the audio chunks in MS
chunk_duration = 1000
# Sample rate of the audio
sample_rate = 44100

[UI]
# Name of the assistant.
name = "Assistant"

# Description of the assistant. This is used for HTML tags.
# description = ""

# Large content is collapsed by default for a cleaner UI
default_collapse_content = true

# Chain of Thought (CoT) display mode. Can be "hidden", "tool_call" or "full".
cot = "full"

# Link to your github repo. This will add a github button in the UI's header.
# github = ""

# Specify a CSS file that can be used to customize the user interface.
# The CSS file can be served from the public directory or via an external link.
# custom_css = "/public/test.css"

# Specify a Javascript file that can be used to customize the user interface.
# The Javascript file can be served from the public directory.
# custom_js = "/public/test.js"

# Specify a custom font url.
# custom_font = "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap"

# Specify a custom meta image url.
# custom_meta_image_url = "https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png"

# Specify a custom build directory for the frontend.
# This can be used to customize the frontend code.
# Be careful: If this is a relative path, it should not start with a slash.
# custom_build = "./public/build"

[UI.theme]
default = "dark"
#layout = "wide"
#font_family = "Inter, sans-serif"
# Override default MUI light theme. (Check theme.ts)
[UI.theme.light]
#background = "#FAFAFA"
#paper = "#FFFFFF"

[UI.theme.light.primary]
#main = "#F80061"
#dark = "#980039"
#light = "#FFE7EB"
[UI.theme.light.text]
#primary = "#212121"
#secondary = "#616161"

# Override default MUI dark theme. (Check theme.ts)
[UI.theme.dark]
#background = "#FAFAFA"
#paper = "#FFFFFF"

[UI.theme.dark.primary]
#main = "#F80061"
#dark = "#980039"
#light = "#FFE7EB"
[UI.theme.dark.text]
#primary = "#EEEEEE"
#secondary = "#BDBDBD"

[meta]
generated_by = "1.2.0"
7 changes: 7 additions & 0 deletions community/routing-multisource-rag/.gitignore
@@ -0,0 +1,7 @@
.DS_Store
*.db
*.db.lock
.venv
.env

.chainlit/translations
37 changes: 37 additions & 0 deletions community/routing-multisource-rag/README.md
@@ -0,0 +1,37 @@
# LlamaIndex + NVIDIA
This project shows how to use NVIDIA's large language model APIs with LlamaIndex's Workflow and ingestion functionality. Chainlit provides the chat UI.

### Project highlights:
- Interfaces with chat completion and embedding models from [build.nvidia.com](https://build.nvidia.com)
- Routes queries based on whether they require access to a document database
- Answers queries using the [Perplexity API for web search](https://docs.perplexity.ai/home)
- Performs vector lookup [using Milvus Lite](https://milvus.io/docs/milvus_lite.md) (WIP)
- Stores user chat history with a local SQLite database (WIP)

### Technologies used:
- **Frontend**: Chainlit
- **Web search**: Perplexity API
- **LLM**: Llama 3.1 8b and Mistral Large 2
- **Database**: Milvus Lite
- **Chat application**: LlamaIndex Workflows

![System architecture diagram](architecture.png)

### Getting started
To run this code, make sure you have the following environment variables set:
- `NVIDIA_API_KEY` for access to the NVIDIA LLM APIs (required). You can set this by running `export NVIDIA_API_KEY="nvapi-*******************"`. If you don't have an API key, follow [these instructions](https://github.com/NVIDIA/GenerativeAIExamples/blob/main/docs/api-catalog.md#get-an-api-key-for-the-accessing-models-on-the-api-catalog) to sign up for an NVIDIA AI Foundation developer account and obtain access.
- `PERPLEXITY_API_KEY` (optional) if you want to use Perplexity to answer queries with web search.

Then clone this project, (optionally) create a new Python virtual environment, install the dependencies with `pip install -r requirements.txt`, and start the application with `chainlit run app.py` from this directory. The application should then be available at http://localhost:8000.
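If you want to confirm the keys are visible before launching, a quick check along these lines works (this snippet is illustrative and not part of the repository):

```python
# Optional sanity check for the environment variables described above.
import os

from dotenv import load_dotenv

load_dotenv()  # app.py also loads a local .env file if one is present

if not os.getenv("NVIDIA_API_KEY"):
    raise SystemExit("NVIDIA_API_KEY is not set (expected a key like nvapi-...)")
if not os.getenv("PERPLEXITY_API_KEY"):
    print("PERPLEXITY_API_KEY not set; web search answers will be unavailable.")
```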

### Design
This project uses Chainlit to host a combined frontend and backend. The chat logic is implemented as a LlamaIndex Workflow class, which runs the user's query through the following steps (a minimal sketch follows the list):
- Decide whether the query warrants using the LLM with or without RAG (`QueryFlow.route_query`)
- If RAG is used, rewrite the query (`QueryFlow.rewrite_query`) into a form better suited for web search and document retrieval, and produce a vector embedding for it (`QueryFlow.embed_query`)
- Retrieve documents from the document database (`QueryFlow.milvus_retrieve`)
- Solicit an answer from the Perplexity API (`QueryFlow.pplx_retrieve`)
- Combine the results to generate a final response, which is streamed to the user (`QueryFlow.synthesize_response`)
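The sketch below shows how steps like these can be wired together with LlamaIndex Workflows. The event classes, routing rule, and step bodies are illustrative placeholders, not the actual `workflow.py` implementation:

```python
# Minimal, self-contained sketch of a routed workflow. The real QueryFlow
# also embeds the query, searches Milvus, and calls the Perplexity API.
from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step


class QueryEvent(Event):
    query: str  # the (possibly rewritten) user query


class ContextEvent(Event):
    context: str  # text retrieved from a source


class QueryFlowSketch(Workflow):
    @step
    async def route_query(self, ev: StartEvent) -> QueryEvent | StopEvent:
        # A small routing LLM decides this in the real app; this placeholder
        # rule skips retrieval for queries that don't mention documents.
        if "document" not in ev.query.lower():
            return StopEvent(result="direct answer, no retrieval")
        return QueryEvent(query=ev.query)

    @step
    async def retrieve(self, ev: QueryEvent) -> ContextEvent:
        # Vector lookup and web search would both feed in here.
        return ContextEvent(context=f"context for: {ev.query}")

    @step
    async def synthesize_response(self, ev: ContextEvent) -> StopEvent:
        # The final LLM call, streamed to the user in the real app.
        return StopEvent(result=f"answer grounded in: {ev.context}")
```

Each step's return type tells the workflow engine where events may flow next, so `await QueryFlowSketch(timeout=60).run(query="...")` follows whichever branch the router picks.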

![Workflow diagram](diagram.png)
111 changes: 111 additions & 0 deletions community/routing-multisource-rag/app.py
@@ -0,0 +1,111 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time

import chainlit as cl
from dotenv import load_dotenv
from llama_index.core import Settings

from workflow import QueryFlow

load_dotenv()



@cl.on_chat_start
async def on_chat_start():
    cl.user_session.set("message_history", [])
    workflow = QueryFlow(timeout=90, verbose=False)
    cl.user_session.set("workflow", workflow)


@cl.set_starters
async def set_starters():
    return [
        cl.Starter(
            label="Write a haiku about CPUs",
            message="Write a haiku about CPUs.",
            icon="/avatars/servers",
        ),
        cl.Starter(
            label="Write Docker Compose",
            message="Write a Docker Compose file for deploying a web app with a Redis cache and Postgres database",
            icon="/avatars/screen",
        ),
        cl.Starter(
            label="What NIMs are available?",
            message="Summarize the different large language models that have NVIDIA inference microservices (NIMs) available for them. List as many as you can.",
            icon="/avatars/container",
        ),
        cl.Starter(
            label="Summarize BioNemo use cases",
            message="Write a table summarizing how customers are using bionemo. Use one sentence per customer and include columns for customer, industry, and use case. Make the table between 5 to 10 rows and relatively narrow.",
            icon="/avatars/dna",
        ),
    ]


@cl.on_chat_end
def end():
    logging.info("Chat ended.")


@cl.on_message
async def main(user_message: cl.Message, count_tokens: bool = True):
    """
    Executes when a user sends a message. We send the message off to the
    LlamaIndex workflow for a streaming answer. When the answer is done
    streaming, we go back over the response to identify the sources used,
    and then add a block of text about the sources.
    """
    msg_start_time = time.time()
    logging.info(f"Received message: <{user_message.content[0:50]}...> ")
    message_history = cl.user_session.get("message_history", [])

    # Use the per-session workflow created in on_chat_start.
    workflow = cl.user_session.get("workflow")

    assistant_message = cl.Message(content="")

    token_count = 0
    with cl.Step(name="Mistral Large 2", type="tool"):
        response, source_nodes = await workflow.run(
            query=user_message.content,
            chat_messages=message_history,
        )

        async for chunk in response:
            token_count += 1
            chars = chunk.delta
            await assistant_message.stream_token(chars)

    msg_time = time.time() - msg_start_time
    logging.info(f"Message generated in {msg_time:.1f} seconds.")

    message_history += [
        {"role": "user", "content": user_message.content},
        {"role": "assistant", "content": assistant_message.content},
    ]

    cl.user_session.set("message_history", message_history)

    await assistant_message.send()
14 changes: 14 additions & 0 deletions community/routing-multisource-rag/chainlit.md
@@ -0,0 +1,14 @@
# Welcome to Chainlit! 🚀🤖

Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.

## Useful Links 🔗

- **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
- **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬

We can't wait to see what you create with Chainlit! Happy coding! 💻😊

## Welcome screen

To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.
87 changes: 87 additions & 0 deletions community/routing-multisource-rag/config.py
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from pydantic import BaseModel, Field


class WorkflowConfig(BaseModel):
    perplexity_timeout: int = Field(
        default=20, description="Timeout in seconds for Perplexity API call"
    )
    source_field_name: str = Field(
        default="source_uri",
        description="Field name for source URI in document metadata",
    )
    display_field_name: str = Field(
        default="display_name",
        description="Field name for display name in document metadata",
    )
    n_messages_in_history: int = Field(
        default=6, description="Number of messages to include in chat history"
    )
    max_tokens_generated: int = Field(
        default=1024, description="Maximum number of tokens to generate in response"
    )
    context_window: int = Field(
        default=128_000, description="Size of the context window for the LLM"
    )

    chat_model_name: str = Field(
        default="mistralai/mistral-large-2-instruct",
        description="Model for final response synthesis.",
    )
    routing_model_name: str = Field(
        default="meta/llama-3.1-8b-instruct",
        description="Model for performing query routing. Can be a bit dumber.",
    )
    perplexity_model_name: str = Field(
        default="llama-3.1-sonar-large-128k-online",
        description="Name of the Perplexity model; alternatives are huge and small.",
    )
    embedding_model_name: str = Field(
        default="nvidia/nv-embed-v1", description="Name of the embedding model"
    )
    embedding_model_dim: int = Field(
        default=4096, description="Dimension of the embedding model"
    )
    similarity_top_k: int = Field(
        default=5,
        description="Number of similar documents to return from vector search",
    )

    nvidia_api_key: str = Field(
        default=os.getenv("NVIDIA_API_KEY"), description="NVIDIA API key"
    )
    perplexity_api_key: str = Field(
        default=os.getenv("PERPLEXITY_API_KEY"),
        description="Perplexity API key (optional)",
    )

    data_dir: str = Field(
        default="data", description="Directory containing the documents to be indexed"
    )
    milvus_path: str = Field(
        default="db/milvus_lite.db", description="Path to the Milvus database"
    )

    def __init__(self, **data):
        super().__init__(**data)
        if not self.nvidia_api_key:
            raise ValueError("NVIDIA_API_KEY is required and must not be null")


config = WorkflowConfig()
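
# Illustrative usage (hypothetical values, not part of the module): because
# WorkflowConfig is a standard pydantic model, any default above can be
# overridden per instance, e.g. when pointing at a test database:
#
#   test_config = WorkflowConfig(milvus_path="db/test.db", similarity_top_k=3)
#   assert test_config.similarity_top_k == 3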
Binary file added community/routing-multisource-rag/diagram.png