Fix time unit embedding generation (#251)
dylan-eustice authored Nov 29, 2024
1 parent 93c2c84 commit c36d09d
Showing 4 changed files with 49 additions and 81 deletions.
4 changes: 2 additions & 2 deletions community/fm-asr-streaming-rag/chain-server/chains.py
@@ -21,8 +21,8 @@

from accumulator import TextAccumulator
from retriever import NVRetriever
-from common import get_logger, LLMConfig, TimeResponse, UserIntent
-from utils import get_llm, classify
+from common import get_logger, LLMConfig
+from utils import get_llm, classify, TimeResponse, UserIntent
from prompts import RAG_PROMPT, INTENT_PROMPT, RECENCY_PROMPT, SUMMARIZATION_PROMPT

logger = get_logger(__name__)
78 changes: 1 addition & 77 deletions community/fm-asr-streaming-rag/chain-server/common.py
@@ -14,15 +14,10 @@
# limitations under the License.

import logging
-import requests
-import json
import os
-import numpy as np

-from datetime import datetime, timedelta
+from datetime import datetime
from pydantic import BaseModel, Field
-from typing import Literal
-from langchain_community.utils.math import cosine_similarity

NVIDIA_API_KEY = os.environ.get('NVIDIA_API_KEY', 'null')
LLM_URI = os.environ.get('LLM_URI', None)
@@ -65,74 +60,3 @@ class LLMConfig(BaseModel):
temperature: float = Field("Temperature of the LLM response")
max_docs: int = Field("Maximum number of documents to return")
num_tokens: int = Field("The maximum number of tokens in the response")

-def nemo_embedding(text):
-    """
-    Uses the NeMo Embedding MS to convert text to embeddings
-    - ex: embeddings = nemo_embedding(['Chunk A', 'Chunk B'])
-    """
-    port = os.environ.get('NEMO_EMBEDDING_PORT', 1985)
-    url = f"http://localhost:{port}/v1/embeddings"
-    payload = json.dumps({
-        "input": text,
-        "model": "NV-Embed-QA",
-        "input_type": "query"
-    })
-    headers = {'Content-Type': 'application/json'}
-    response = requests.request("POST", url, headers=headers, data=payload)
-    embeddings = [chunk['embedding'] for chunk in response.json()['data']]
-    return embeddings
-
-def nvapi_embedding(text):
-    session = requests.Session()
-    url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
-    headers = {
-        "Authorization": f"Bearer {NVIDIA_API_KEY}",
-        "Accept": "application/json",
-    }
-    payload = {
-        "input": text,
-        "input_type": "passage",
-        "model": "NV-Embed-QA"
-    }
-    response = session.post(url, headers=headers, json=payload)
-    embeddings = [chunk['embedding'] for chunk in response.json()['data']]
-    return embeddings
-
-VALID_TIME_UNITS = ["seconds", "minutes", "hours", "days"]
-TIME_VECTORS = nvapi_embedding(VALID_TIME_UNITS)
-
-def sanitize_time_unit(time_unit):
-    """
-    For cases where an LLM returns a time unit that doesn't match one of the
-    discrete options, find the closest with cosine similarity.
-    Example: 'min' -> 'minutes'
-    """
-    if time_unit in VALID_TIME_UNITS:
-        return time_unit
-
-    unit_embedding = nvapi_embedding([time_unit])
-    similarity = cosine_similarity(unit_embedding, TIME_VECTORS)
-    return VALID_TIME_UNITS[np.argmax(similarity)]
-
-"""
-Pydantic classes that are used to detect user intent and plan accordingly
-"""
-class TimeResponse(BaseModel):
-    timeNum: float = Field("The number of time units the user asked about")
-    timeUnit: str = Field("The unit of time the user asked about")
-
-    def to_seconds(self):
-        """ Return the total number of seconds this represents
-        """
-        self.timeUnit = sanitize_time_unit(self.timeUnit)
-        return timedelta(**{self.timeUnit: self.timeNum}).total_seconds()
-
-class UserIntent(BaseModel):
-    intentType: Literal[
-        "SpecificTopic",
-        "RecentSummary",
-        "TimeWindow",
-        "Unknown"
-    ] = Field("The intent of user's query")
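
Note: the helpers removed above generated the time-unit embeddings by issuing raw HTTP requests (to a local NeMo embedding microservice or to ai.api.nvidia.com) at module import time. The new code in utils.py below goes through the shared embedder client from get_embedder() instead. As a rough standalone illustration of that pattern (not part of this diff; model name and configuration are assumptions, and it requires NVIDIA_API_KEY and a reachable NV-Embed-QA endpoint):

# Illustrative sketch only: produce the time-unit vectors via the LangChain
# NVIDIA embeddings client rather than hand-rolled HTTP requests.
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

embedder = NVIDIAEmbeddings(model="NV-Embed-QA", truncate="NONE")  # assumed config
vectors = embedder.embed_documents(["seconds", "minutes", "hours", "days"])
print(len(vectors))  # 4 embedding vectors, one per time unit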
2 changes: 1 addition & 1 deletion community/fm-asr-streaming-rag/chain-server/prompts.py
@@ -14,7 +14,7 @@
# limitations under the License.

from pydantic import BaseModel
-from common import UserIntent, TimeResponse
+from utils import UserIntent, TimeResponse

def format_schema(pydantic_obj: BaseModel):
return str(
46 changes: 45 additions & 1 deletion community/fm-asr-streaming-rag/chain-server/utils.py
@@ -15,9 +15,14 @@

import json
import re
+import json
+import numpy as np

+from datetime import timedelta
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from langchain_community.utils.math import cosine_similarity
+from typing import Literal

from common import (
get_logger,
@@ -97,6 +102,45 @@ def get_embedder(local: bool=True):
truncate="NONE"
)

+embed_client = get_embedder()
+VALID_TIME_UNITS = ["seconds", "minutes", "hours", "days"]
+TIME_VECTORS = embed_client.embed_documents(VALID_TIME_UNITS)
+
+def sanitize_time_unit(time_unit):
+    """
+    For cases where an LLM returns a time unit that doesn't match one of the
+    discrete options, find the closest with cosine similarity.
+    Example: 'min' -> 'minutes'
+    """
+    if time_unit in VALID_TIME_UNITS:
+        return time_unit
+
+    unit_embedding = embed_client.embed_documents([time_unit])
+    similarity = cosine_similarity(unit_embedding, TIME_VECTORS)
+    return VALID_TIME_UNITS[np.argmax(similarity)]
+
+"""
+Pydantic classes that are used to detect user intent and plan accordingly
+"""
+class TimeResponse(BaseModel):
+    timeNum: float = Field("The number of time units the user asked about")
+    timeUnit: str = Field("The unit of time the user asked about")
+
+    def to_seconds(self):
+        """ Return the total number of seconds this represents
+        """
+        self.timeUnit = sanitize_time_unit(self.timeUnit)
+        return timedelta(**{self.timeUnit: self.timeNum}).total_seconds()
+
+class UserIntent(BaseModel):
+    intentType: Literal[
+        "SpecificTopic",
+        "RecentSummary",
+        "TimeWindow",
+        "Unknown"
+    ] = Field("The intent of user's query")

def classify(question, chain, pydantic_obj: BaseModel):
""" Parse a question into structured pydantic_obj
"""
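
For reference, a minimal usage sketch of the relocated helpers (assumes the chain-server directory is importable and the configured embedding endpoint is reachable; exact outputs depend on the embedding model, so they are shown as expectations only):

from utils import sanitize_time_unit, TimeResponse

# An LLM may emit a non-canonical unit such as 'min'; cosine similarity against
# the precomputed TIME_VECTORS should map it back to 'minutes'.
print(sanitize_time_unit("min"))   # expected: 'minutes'

# TimeResponse normalizes its unit before converting to a number of seconds.
window = TimeResponse(timeNum=5, timeUnit="hrs")
print(window.to_seconds())         # expected: 18000.0 (5 hours)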
