Skip to content

Commit

Permalink
Update FM streaming RAG demo to use new NIMs (#234)
Browse files Browse the repository at this point in the history
* Update FM streaming RAG demo for DC AI Summit

- Uses new embedding and reranking NIMs
- Uses Parakeet ASR NIM
- Includes GNU Radio container
- Updates deployment scripts for deploying across multiple machines
- Various bug fixes and QoL updates

* Add diagramss for README

* Add README

* By default, use build.nvidia.com endpoints for non-ASR NIMs

* Restart Riva gRPC service on exceptions

* Add GNU Radio folder

* Update README to include frontend URL

* Fix LLM model options

* Fix bug when not using local Embedding / Reranking models
  • Loading branch information
dylan-eustice authored Nov 7, 2024
1 parent 63bb57f commit 75ccb42
Show file tree
Hide file tree
Showing 51 changed files with 914 additions and 2,131 deletions.
2 changes: 2 additions & 0 deletions community/fm-asr-streaming-rag/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
deploy/.keys
milvus/volumes
169 changes: 85 additions & 84 deletions community/fm-asr-streaming-rag/README.md

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions community/fm-asr-streaming-rag/chain-server/accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import datetime
import requests
from common import get_logger
from database import TimestampDatabase
from langchain.text_splitter import RecursiveCharacterTextSplitter

logger = get_logger(__name__)

FRONTEND_URI = os.environ.get('FRONTEND_URI', None)

#todo: Multi-thread to handle multiple concurrent streams
#todo: Add time-triggered embedding (i.e. embed after N seconds if no updates)
class TextAccumulator:
Expand All @@ -45,4 +50,14 @@ def update(self, source_id, text):
self.timestamp_db.insert_docs(new_docs, source_id)
self.db_interface.add_docs(new_docs, source_id)

for doc in new_docs:
endpoint = f"http://{FRONTEND_URI}/app/update_finalized_transcript"
time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try:
client_response = requests.post(endpoint, json={'transcript': f"[{time}] {doc}"})
logger.debug(f'Posted update_finalized_transcript: {client_response._content}')
logger.debug("--------------------------")
except requests.exceptions.ConnectionError:
logger.error(f"Failed to connect to the '{endpoint}' endpoint")

return {"status": f"Added {len(new_docs)} entries"}
5 changes: 2 additions & 3 deletions community/fm-asr-streaming-rag/chain-server/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from copy import copy
from datetime import datetime, timedelta
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.docstore.document import Document

from accumulator import TextAccumulator
from retriever import NemoRetrieverInterface, NvidiaApiInterface
from retriever import NVRetriever
from common import get_logger, LLMConfig, TimeResponse, UserIntent
from utils import get_llm, classify
from prompts import RAG_PROMPT, INTENT_PROMPT, RECENCY_PROMPT, SUMMARIZATION_PROMPT
Expand All @@ -36,7 +35,7 @@ def __init__(
self,
config: LLMConfig,
text_accumulator: TextAccumulator,
retv_interface: Union[NemoRetrieverInterface, NvidiaApiInterface]
retv_interface: NVRetriever
):
self.config = config
self.text_accumulator = text_accumulator
Expand Down
22 changes: 10 additions & 12 deletions community/fm-asr-streaming-rag/chain-server/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@
from typing import Literal
from langchain_community.utils.math import cosine_similarity

USE_NEMO_RETRIEVER = os.environ.get('USE_NEMO_RETRIEVER', 'False').lower() in ('true', '1')
NVIDIA_API_KEY = os.environ.get('NVIDIA_API_KEY', 'null')
NVIDIA_API_KEY = os.environ.get('NVIDIA_API_KEY', 'null')
LLM_URI = os.environ.get('LLM_URI', None)
RERANKING_MODEL = os.environ.get('RERANK_MODEL', None)
RERANKING_URI = os.environ.get('RERANK_URI', None)
EMBEDDING_MODEL = os.environ.get('EMBED_MODEL', None)
EMBEDDING_URI = os.environ.get('EMBED_URI', None)
MAX_DOCS = int(os.environ.get('MAX_DOCS_RETR', 8))

def get_logger(name):
LOG_LEVEL = logging.getLevelName(os.environ.get('CHAIN_LOG_LEVEL', 'WARN').upper())
Expand All @@ -51,7 +56,7 @@ class LLMConfig(BaseModel):
)
# Model choice
name: str = Field("Name of LLM instance to use")
engine: str = Field("Name of engine ['nvai-api-endpoint', 'triton-trt-llm']")
engine: str = Field("Name of engine ['nvai-api-endpoint', 'local-nim']")
# Chain parameters
use_knowledge_base: bool = Field(
description="Whether to use a knowledge base", default=True
Expand Down Expand Up @@ -95,11 +100,7 @@ def nvapi_embedding(text):
return embeddings

VALID_TIME_UNITS = ["seconds", "minutes", "hours", "days"]
TIME_VECTORS = None # Lazy loading in 'sanitize_time_unit'
if USE_NEMO_RETRIEVER:
embedding_service = nemo_embedding
else:
embedding_service = nvapi_embedding
TIME_VECTORS = nvapi_embedding(VALID_TIME_UNITS)

def sanitize_time_unit(time_unit):
"""
Expand All @@ -111,10 +112,7 @@ def sanitize_time_unit(time_unit):
if time_unit in VALID_TIME_UNITS:
return time_unit

if TIME_VECTORS is None:
TIME_VECTORS = embedding_service(VALID_TIME_UNITS)

unit_embedding = embedding_service([time_unit])
unit_embedding = nvapi_embedding([time_unit])
similarity = cosine_similarity(unit_embedding, TIME_VECTORS)
return VALID_TIME_UNITS[np.argmax(similarity)]

Expand Down
6 changes: 3 additions & 3 deletions community/fm-asr-streaming-rag/chain-server/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import numpy as np

from common import get_logger
from datetime import datetime
from datetime import datetime, timedelta
from langchain.docstore.document import Document

logger = get_logger(__name__)
Expand Down Expand Up @@ -87,8 +87,8 @@ def recent(self, tstamp):
def past(self, tstamp, window=90):
""" Return entries within 'window' seconds of tstamp
"""
tstart = tstamp - datetime.timedelta(seconds=window)
tend = tstamp + datetime.timedelta(seconds=window)
tstart = tstamp - timedelta(seconds=window)
tend = tstamp + timedelta(seconds=window)
self.cursor.execute('SELECT * FROM messages WHERE timestamp BETWEEN ? AND ?', (tstart, tend))
docs = self.cursor.fetchall()
return [self.reformat(doc) for doc in docs]
10 changes: 8 additions & 2 deletions community/fm-asr-streaming-rag/chain-server/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def format_json(text: str):
recent timeframe. Examples: "Can you summarize the last hour of content?", "What have the main \
topics been over the last 5 minutes?", "Tell me the main stories of the past 2 hours.".
- 'TimeWindow': If the user is asking about the focus of the conversation from a specified time in \
the past. Examples: "What were they talking about 15 minutes ago?", "What was the focus an hour ago?".
the past. Examples: "What were they talking about 15 minutes ago?", "What was the focus an hour ago?", "What are they talking about right now?".
- If the user's intent is not clear, or if the intent cannot be confidently determined, classify \
this as 'Unknown'.
Expand Down Expand Up @@ -101,11 +101,17 @@ def format_json(text: str):
f"'{recency_examples[1]}' --> '{format_json(recency_examples_obj[1].model_dump_json())}'\n" +
f"'{recency_examples[2]}' --> '{format_json(recency_examples_obj[2].model_dump_json())}'\n" + """
Convert the user input below into this JSON format.
Convert the user input below into this JSON format. Make sure you use valid JSON and \
don't worry about escaping quotes, just give valid JSON blobs.
""")

SUMMARIZATION_PROMPT = """\
You are a sophisticated summarization tool designed to condense large blocks \
of text into a concise summary. Given the user text, reduce the character \
count by distilling into only the most important information.
Do not say you are summarizing, i.e. do not say "Here's a summary...", just \
condense the text to the best of your abilities.
Summary:
"""
21 changes: 3 additions & 18 deletions community/fm-asr-streaming-rag/chain-server/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,21 +1,6 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
unstructured[all-docs]==0.11.2
sentence-transformers==2.2.2
llama-index==0.9.22
pymilvus==2.3.5
dataclass-wizard==0.22.2
opencv-python==4.8.0.74
minio==7.2.0
asyncpg==0.29.0
psycopg2-binary==2.9.9
pgvector==0.2.4
langchain==0.1.14
langchain-core==0.1.40
langchain-nvidia-ai-endpoints==0.0.12
langchain-nvidia-trt==0.0.1rc0
nemollm==0.3.4
opentelemetry-sdk==1.21.0
opentelemetry-api==1.21.0
opentelemetry-exporter-otlp-proto-grpc==1.21.0
langchain==0.2.6
langchain_core==0.2.25
langchain_nvidia_ai_endpoints==0.2.0
Loading

0 comments on commit 75ccb42

Please sign in to comment.