This repository has been archived by the owner on Jan 5, 2025. It is now read-only.

Merge pull request #239 from lvalics/main
Ollama LLM and conversational retrieval...
lvalics authored Feb 18, 2024
2 parents cc35064 + 6ef2df6 commit b0b49eb
Showing 11 changed files with 92 additions and 101 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -8,3 +8,5 @@ dj_backend_server/nginx/nginx.conf
dj_backend_server.code-workspace
.aider*
.aiderignore
dj_backend_server/.vscode/settings.json

3 changes: 2 additions & 1 deletion dj_backend_server/.gitignore
@@ -37,4 +37,5 @@ pip-delete-this-directory.txt
website_data_sources/*
venv
open-llama-7B-open-instruct.ggmlv3.q4_K_M.bin
llama-2-7b-chat.ggmlv3.q4_K_M.bin
llama-2-7b-chat.ggmlv3.q4_K_M.bin
.vscode/
4 changes: 4 additions & 0 deletions dj_backend_server/CHANGELOG.MD
@@ -1,3 +1,7 @@
2.18.2024
- The conversational retrieval functionality is now operating as expected. It successfully sends the conversation history to the language model, allowing the context from previous interactions to be utilized effectively.
- Added support for Ollama as the Language Model (LLM). Ensure Ollama is specified in the .env configuration and the model is preloaded on the server.

2.17.2024
- Incorporate 'Ollama' into your example.env configuration and make sure to reflect these changes in your .env file for compatibility.
- We've expanded the logging capabilities within settings.py by deploying logging.debug for more detailed insights, although it remains inactive when the DEBUG mode is off.
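For orientation, below is a minimal sketch of the pattern this commit introduces: an Ollama-served model driving a conversational retrieval chain that receives prior turns as chat history. It is an illustration under assumptions (a running Ollama server, the langchain_community ChatOllama wrapper, an existing vector store), not the repository's exact code; the environment variable names mirror OLLAMA_URL and OLLAMA_MODEL_NAME from example.env.

import os

from langchain.chains import ConversationalRetrievalChain
from langchain_community.chat_models import ChatOllama


def build_conversational_chain(vector_store):
    """Wire an Ollama-backed LLM into a conversational retrieval chain (sketch)."""
    llm = ChatOllama(
        base_url=os.environ.get("OLLAMA_URL", "http://localhost:11434"),
        model=os.environ.get("OLLAMA_MODEL_NAME", "llama2"),
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_store.as_retriever(),
    )


# Chat history is supplied as (human, ai) pairs so earlier turns shape the answer:
# chain = build_conversational_chain(vector_store)
# result = chain.invoke({
#     "question": "Which vector store does the backend use?",
#     "chat_history": [("What is this project?", "A Django chatbot backend.")],
# })
# print(result["answer"])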
18 changes: 10 additions & 8 deletions dj_backend_server/api/utils/get_embeddings.py
@@ -18,23 +18,23 @@ def get_azure_embedding():
deployment = os.environ.get("AZURE_OPENAI_EMBEDDING_MODEL_NAME")
openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
client = os.environ.get("AZURE_OPENAI_API_TYPE")
openai_api_base = os.environ['AZURE_OPENAI_API_BASE']
openai_api_version = os.environ['AZURE_OPENAI_API_VERSION']
openai_api_base = os.environ["AZURE_OPENAI_API_BASE"]
openai_api_version = os.environ["AZURE_OPENAI_API_VERSION"]

return OpenAIEmbeddings(
openai_api_key=openai_api_key,
deployment=deployment,
client=client,
chunk_size=8,
openai_api_base=openai_api_base,
openai_api_version=openai_api_version
openai_api_version=openai_api_version,
)


def get_openai_embedding():
"""Gets embeddings using the OpenAI embedding provider."""
openai_api_key = os.environ.get("OPENAI_API_KEY")
return OpenAIEmbeddings(openai_api_key=openai_api_key, chunk_size=1)
return OpenAIEmbeddings(openai_api_key=openai_api_key, chunk_size=1)


def get_llama2_embedding():
@@ -48,15 +48,17 @@ def choose_embedding_provider():

if embedding_provider == EmbeddingProvider.azure.value:
return get_azure_embedding()

elif embedding_provider == EmbeddingProvider.OPENAI.value:
return get_openai_embedding()

elif embedding_provider == EmbeddingProvider.llama2.value:
return get_llama2_embedding()

else:
available_providers = ", ".join([service.value for service in EmbeddingProvider])
available_providers = ", ".join(
[service.value for service in EmbeddingProvider]
)
raise ValueError(
f"Embedding service '{embedding_provider}' is not currently available. "
f"Available services: {available_providers}"
@@ -66,4 +68,4 @@ def choose_embedding_provider():
# Main function to get embeddings
def get_embeddings() -> Embeddings:
"""Gets embeddings using the chosen embedding provider."""
return choose_embedding_provider()
return choose_embedding_provider()
62 changes: 23 additions & 39 deletions dj_backend_server/api/utils/get_openai_llm.py
@@ -5,15 +5,11 @@
from django.utils.timezone import make_aware
from datetime import datetime, timezone
from uuid import uuid4
from ollama import Client
from openai import OpenAI
from django.conf import settings
from langchain_openai.chat_models import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import AzureOpenAI
from langchain_community.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from web.models.failed_jobs import FailedJob
@@ -62,12 +58,7 @@ def get_llama_llm():
def get_azure_openai_llm():
"""Returns AzureOpenAI instance configured from environment variables"""
try:
if settings.DEBUG:
openai_api_type = "openai" # JUST FOR DEVELOPMENT
logging.debug(f"DEVELOPMENT Using API Type: {openai_api_type}")
else:
openai_api_type = os.environ["AZURE_OPENAI_API_TYPE"]

openai_api_type = os.environ["AZURE_OPENAI_API_TYPE"]
openai_api_key = os.environ["AZURE_OPENAI_API_KEY"]
openai_deployment_name = os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"]
openai_model_name = os.environ["AZURE_OPENAI_COMPLETION_MODEL"]
@@ -134,30 +125,26 @@ def get_openai_llm():
traceback.print_exc()


def get_ollama_llm(sanitized_question):
"""Returns an Ollama Server instance configured from environment variables"""
llm = Client(host=os.environ.get("OLLAMA_URL"))
# Use the client to make a request
def get_ollama_llm():
"""Returns an Ollama instance configured from environment variables"""
try:
if sanitized_question:
response = llm.chat(
model=os.environ.get("OLLAMA_MODEL_NAME"),
messages=[{"role": "user", "content": sanitized_question}],
)
else:
raise ValueError("Question cannot be None.")
if response:
return response
else:
raise ValueError("Invalid response from Ollama.")
base_url = os.environ.get("OLLAMA_URL")
model = os.environ.get("OLLAMA_MODEL_NAME", "llama2")

llm = ChatOllama(
base_url=base_url,
model=model,
callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
return llm

except Exception as e:
logger.debug(f"Exception in get_ollama_llm: {e}")
failed_job = FailedJob(
uuid=str(uuid4()),
connection="default",
queue="default",
payload="get_openai_llm",
payload="get_ollama_llm",
exception=str(e),
failed_at=make_aware(datetime.now(), timezone.utc),
)
@@ -176,29 +163,26 @@ def get_llm():
"ollama": lambda: get_ollama_llm(),
}

# DEVENV
# if settings.DEBUG:
# api_type = "ollama"
api_type = os.environ.get("OPENAI_API_TYPE", "openai")

if api_type not in clients:
raise ValueError(f"Invalid OPENAI_API_TYPE: {api_type}")

logging.debug(f"Using LLM: {api_type}")

if api_type in clients:
if api_type == "ollama":
return clients[api_type]()
elif api_type != "ollama":
return clients[api_type]()
llm_instance = clients[api_type]()
if llm_instance is None:
logger.error(f"LLM instance for {api_type} could not be created.")
return None
return llm_instance
else:
raise ValueError(f"Invalid OPENAI_API_TYPE: {api_type}")

except Exception as e:
failed_job = FailedJob(
uuid=str(uuid4()),
connection="default",
queue="default",
payload="get_llm",
exception=str(e),
failed_at=datetime.now(),
)
failed_job = FailedJob(
uuid=str(uuid4()),
connection="default",
1 change: 1 addition & 0 deletions dj_backend_server/api/utils/make_chain.py
@@ -98,6 +98,7 @@ def getConversationRetrievalChain(
retriever=vector_store.as_retriever(),
verbose=True,
combine_docs_chain_kwargs={"prompt": prompt},
return_source_documents=True,
)
logger.debug(f"ConversationalRetrievalChain {llm}, created: {chain}")
return chain
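With return_source_documents=True set above, the chain's output dictionary carries the retrieved chunks alongside the answer. A short, hypothetical consumer (function and variable names are assumptions, not the project's code):

def log_answer_with_sources(chain, question, chat_history):
    """Invoke the chain and print the answer plus its source documents (sketch)."""
    result = chain.invoke({"question": question, "chat_history": chat_history})
    print(result["answer"])
    for doc in result.get("source_documents", []):
        # Each entry is a langchain Document with page_content and metadata.
        print(doc.metadata.get("source"), doc.page_content[:80])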
14 changes: 4 additions & 10 deletions dj_backend_server/api/views/views_chat.py
@@ -165,21 +165,15 @@ def get_completion_response(
elif chain_type == "conversation_retrieval":
chain = getConversationRetrievalChain(vector_store, mode, initial_prompt)
logger.debug("getConversationRetrievalChain")
chat_history_json = json.dumps(
get_chat_history_for_retrieval_chain(
session_id, limit=20, initial_prompt=initial_prompt
),
ensure_ascii=False,
chat_history = get_chat_history_for_retrieval_chain(
session_id, limit=20, initial_prompt=initial_prompt
)
chat_history_json = ""
logger.debug(f"Formatted Chat_history {chat_history_json}")
logger.debug(f"Formatted Chat_history {chat_history}")

response = chain.invoke(
{"question": sanitized_question, "chat_history": chat_history_json}
{"question": sanitized_question, "chat_history": chat_history},
)
logger.debug(f"response from chain.invoke: {response}")
response_text = response.get("answer")
logger.debug(f"response_text : {response_text}")
try:
# Attempt to parse the response_text as JSON
response_text = json.loads(response_text)
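A note on the change above: ConversationalRetrievalChain expects chat_history as a Python sequence (for example, a list of (human, ai) tuples), not a JSON string, which is why the json.dumps round-trip was dropped. A hedged sketch of shaping stored messages into that form follows; the helper name is an assumption, though the message/from_user fields match the entries built in views_message.py.

def to_chat_history(entries):
    """Pair stored messages into (human, ai) tuples for the retrieval chain (sketch)."""
    history, pending_user = [], None
    for entry in entries:  # entries like [{"message": "...", "from_user": True}, ...]
        if entry["from_user"]:
            pending_user = entry["message"]
        elif pending_user is not None:
            history.append((pending_user, entry["message"]))
            pending_user = None
    return history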
32 changes: 13 additions & 19 deletions dj_backend_server/api/views/views_message.py
@@ -170,26 +170,23 @@ def send_chat(request):
"""
try:

if settings.DEBUG:
logger.debug("Entering send_chat function")
logger.debug("Entering send_chat function")
# You can add additional validation for 'history' and 'content_type' if needed.

bot_token = request.headers.get("X-Bot-Token")
bot = get_object_or_404(Chatbot, token=bot_token)

data = json.loads(request.body)
if settings.DEBUG:
logger.debug(
f"Request data: {data}"
) # {'from': 'user', 'type': 'text', 'content': 'input text from chat'}
logger.debug(
f"Request data: {data}"
) # {'from': 'user', 'type': 'text', 'content': 'input text from chat'}
# Validate the request data
content = data.get("content")
history = data.get("history")
if settings.DEBUG:
logger.debug(f"Content: {content}")
logger.debug(
f"History: {history}"
) # history is a list of chat history - None????
logger.debug(f"Content: {content}")
logger.debug(
f"History: {history}"
) # history is a list of chat history - None????
content_type = data.get("type")

session_id = get_session_id(request=request, bot_id=bot.id)
@@ -198,10 +195,9 @@
{"message": entry.message, "from_user": entry.from_user}
for entry in history
]
if settings.DEBUG:
logger.debug(
f"History entries in JSON: {history_entries} - and history in text from DB: {history}"
)
logger.debug(
f"History entries in JSON: {history_entries} - and history in text from DB: {history}"
)

# Implement the equivalent logic for validation
if not content:
@@ -211,8 +207,7 @@
)

# Implement the equivalent logic to send the HTTP request to the external API
if settings.DEBUG:
logger.debug(f"External API response START")
logger.debug(f"External API response START")
response = requests.post(
os.getenv("APP_URL") + "/api/chat/",
json={
@@ -226,8 +221,7 @@
},
timeout=200,
)
if settings.DEBUG:
logger.debug(f"External API response: {response.text} and {response}")
logger.debug(f"External API response: {response.text} and {response}")

"""
This block will first check if the response content is not empty. If it is empty,
11 changes: 4 additions & 7 deletions dj_backend_server/example.env
@@ -19,7 +19,7 @@ OPENAI_API_TYPE=openai
OPENAI_API_MODEL=gpt-4-1106-preview
OPENAI_API_TEMPERATURE=1

# azure | openai | llama2 | ollama
# azure | openai | llama2 - change only if you know what you do
EMBEDDING_PROVIDER=openai

# If using azure
@@ -30,22 +30,20 @@ EMBEDDING_PROVIDER=openai
# AZURE_OPENAI_DEPLOYMENT_NAME=
# AZURE_OPENAI_COMPLETION_MODEL=gpt-35-turbo


# OLLAMA_URL="" #no trailing slash at the end or will not work.
# OLLAMA_MODEL_NAME="" # ex openchat, llama2 - Be sure you have this on server downloaded "ollama pull openchat"

# Vector Store, PINECONE|QDRANT
STORE=QDRANT


# if using pinecone
# PINECONE_API_KEY=
# PINECONE_ENV=
# VECTOR_STORE_INDEX_NAME=


# if using qdrant
QDRANT_URL=http://qdrant:6333


# optional, defaults to 15
MAX_PAGES_CRAWL=150

@@ -73,5 +71,4 @@ OCR_LLM = '1'

# retrieval_qa | conversation_retrieval, retrieval_qa works better with azure openai
# if you want to use the conversation_retrieval | retrieval_qa chain
CHAIN_TYPE=conversation_retrieval

CHAIN_TYPE=conversation_retrieval
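The CHAIN_TYPE setting above is what selects the chain builder at request time. A minimal sketch of that dispatch, assuming the getConversationRetrievalChain builder seen in views_chat.py and a hypothetical getRetrievalQAChain counterpart:

import os

def build_chain(vector_store, mode, initial_prompt):
    """Choose the chain implementation from the CHAIN_TYPE environment variable (sketch)."""
    if os.environ.get("CHAIN_TYPE", "conversation_retrieval") == "conversation_retrieval":
        return getConversationRetrievalChain(vector_store, mode, initial_prompt)
    return getRetrievalQAChain(vector_store, mode, initial_prompt)  # hypothetical retrieval_qa builder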
7 changes: 6 additions & 1 deletion dj_backend_server/requirements.txt
@@ -25,6 +25,9 @@ drf-spectacular==0.27.1
drf_spectacular.extensions==0.0.2
exceptiongroup==1.1.2
frozenlist==1.4.0
filelock==3.13.1
fsspec==2024.2.0
huggingface-hub==0.20.3
grpcio==1.56.2
grpcio-tools==1.56.2
h11==0.14.0
@@ -71,6 +74,7 @@ qdrant-client==1.7.0
redis==4.6.0
regex==2023.6.3
requests==2.31.0
safetensors==0.4.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
@@ -79,6 +83,8 @@ sqlparse==0.4.4
tenacity==8.2.2
tiktoken==0.6.0
tqdm==4.65.0
tokenizers==0.15.2
transformers==4.37.2
typing-inspect==0.9.0
typing_extensions==4.7.1
tzdata==2023.3
Expand All @@ -88,4 +94,3 @@ wcwidth==0.2.6
yarl==1.9.2
django-cors-headers==4.3.1

