Skip to content

Commit

Permalink
feat: Include images in call to LLM (#971)
Browse files Browse the repository at this point in the history
Co-authored-by: Frances Tibble <[email protected]>
  • Loading branch information
cecheta and frtibble authored May 24, 2024
1 parent 59f55ea commit 5617f82
Show file tree
Hide file tree
Showing 9 changed files with 315 additions and 45 deletions.
62 changes: 56 additions & 6 deletions code/backend/batch/utilities/tools/question_answer_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

from ..common.answer import Answer
from ..common.source_document import SourceDocument
from ..helpers.azure_blob_storage_client import AzureBlobStorageClient
from ..helpers.config.config_helper import ConfigHelper
from ..helpers.env_helper import EnvHelper
from ..helpers.llm_helper import LLMHelper
from ..search.search import Search
from .answering_tool_base import AnsweringToolBase
from openai.types.chat import ChatCompletion

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,6 +64,7 @@ def generate_on_your_data_messages(
question: str,
chat_history: list[dict],
sources: list[SourceDocument],
image_urls: list[str] = [],
) -> list[dict]:
examples = []

Expand Down Expand Up @@ -122,20 +125,41 @@ def generate_on_your_data_messages(
},
*QuestionAnswerTool.clean_chat_history(chat_history),
{
"content": self.config.prompts.answering_user_prompt.format(
sources=documents,
question=question,
),
"content": [
{
"type": "text",
"text": self.config.prompts.answering_user_prompt.format(
sources=documents,
question=question,
),
},
*(
[
{
"type": "image_url",
"image_url": image_url,
}
for image_url in image_urls
]
),
],
"role": "user",
},
]

def answer_question(self, question: str, chat_history: list[dict], **kwargs):
source_documents = Search.get_source_documents(self.search_handler, question)

if self.env_helper.USE_ADVANCED_IMAGE_PROCESSING:
image_urls = self.create_image_url_list(source_documents)
else:
image_urls = []

model = self.env_helper.AZURE_OPENAI_VISION_MODEL if image_urls else None

if self.config.prompts.use_on_your_data_format:
messages = self.generate_on_your_data_messages(
question, chat_history, source_documents
question, chat_history, source_documents, image_urls
)
else:
warnings.warn(
Expand All @@ -145,8 +169,33 @@ def answer_question(self, question: str, chat_history: list[dict], **kwargs):

llm_helper = LLMHelper()

response = llm_helper.get_chat_completion(messages, temperature=0)
response = llm_helper.get_chat_completion(messages, model=model, temperature=0)
clean_answer = self.format_answer_from_response(
response, question, source_documents
)

return clean_answer

def create_image_url_list(self, source_documents):
image_types = self.config.get_advanced_image_processing_image_types()

blob_client = AzureBlobStorageClient()
container_sas = blob_client.get_container_sas()

image_urls = [
doc.source.replace("_SAS_TOKEN_PLACEHOLDER_", container_sas)
for doc in source_documents
if doc.title is not None and doc.title.split(".")[-1] in image_types
]

return image_urls

def format_answer_from_response(
self,
response: ChatCompletion,
question: str,
source_documents: list[SourceDocument],
):
answer = response.choices[0].message.content
logger.debug(f"Answer: {answer}")

Expand All @@ -158,4 +207,5 @@ def answer_question(self, question: str, chat_history: list[dict], **kwargs):
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
)

return clean_answer
32 changes: 4 additions & 28 deletions code/tests/functional/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import pytest
from pytest_httpserver import HTTPServer
from tests.functional.app_config import AppConfig
Expand Down Expand Up @@ -56,7 +57,9 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
)

httpserver.expect_request(
f"/openai/deployments/{app_config.get('AZURE_OPENAI_MODEL')}/chat/completions",
re.compile(
f"/openai/deployments/({app_config.get('AZURE_OPENAI_MODEL')}|{app_config.get('AZURE_OPENAI_VISION_MODEL')})/chat/completions"
),
method="POST",
).respond_with_json(
{
Expand All @@ -82,33 +85,6 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
}
)

httpserver.expect_request(
f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions",
method="POST",
).respond_with_json(
{
"id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9",
"object": "chat.completion",
"created": 1679072642,
"model": app_config.get("AZURE_OPENAI_VISION_MODEL"),
"usage": {
"prompt_tokens": 58,
"completion_tokens": 68,
"total_tokens": 126,
},
"choices": [
{
"message": {
"role": "assistant",
"content": "This is a caption for the image",
},
"finish_reason": "stop",
"index": 0,
}
],
}
)

httpserver.expect_request(
f"/indexes('{app_config.get('AZURE_SEARCH_CONVERSATIONS_LOG_INDEX')}')/docs/search.index",
method="POST",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json
import re
from unittest.mock import ANY
import pytest
import requests
from pytest_httpserver import HTTPServer
Expand Down Expand Up @@ -68,7 +71,9 @@ def completions_mocking(httpserver: HTTPServer, app_config: AppConfig):
)

httpserver.expect_oneshot_request(
f"/openai/deployments/{app_config.get('AZURE_OPENAI_MODEL')}/chat/completions",
re.compile(
f"/openai/deployments/({app_config.get('AZURE_OPENAI_MODEL')}|{app_config.get('AZURE_OPENAI_VISION_MODEL')})/chat/completions"
),
method="POST",
).respond_with_json(
{
Expand Down Expand Up @@ -112,6 +117,30 @@ def completions_mocking(httpserver: HTTPServer, app_config: AppConfig):
}
)

httpserver.expect_oneshot_request(
f"/indexes('{app_config.get('AZURE_SEARCH_INDEX')}')/docs/search.post.search",
method="POST",
).respond_with_json(
{
"value": [
{
"@search.score": 0.02916666865348816,
"id": "doc_1",
"content": "content",
"content_vector": [
-0.012909674,
0.00838491,
],
"metadata": '{"id": "doc_1", "source": "https://source_SAS_TOKEN_PLACEHOLDER_", "title": "/documents/doc.png", "chunk": 95, "offset": 202738, "page_number": null}',
"title": "/documents/doc.png",
"source": "https://source_SAS_TOKEN_PLACEHOLDER_",
"chunk": 95,
"offset": 202738,
}
]
}
)


def test_post_responds_successfully(app_url: str, app_config: AppConfig):
# when
Expand All @@ -124,7 +153,7 @@ def test_post_responds_successfully(app_url: str, app_config: AppConfig):
{
"messages": [
{
"content": r'{"citations": [{"content": "[/documents/doc.pdf](https://source)\n\n\ncontent", "id": "doc_1", "chunk_id": 95, "title": "/documents/doc.pdf", "filepath": "source", "url": "[/documents/doc.pdf](https://source)", "metadata": {"offset": 202738, "source": "https://source", "markdown_url": "[/documents/doc.pdf](https://source)", "title": "/documents/doc.pdf", "original_url": "https://source", "chunk": 95, "key": "doc_1", "filename": "source"}}], "intent": "What is the meaning of life?"}',
"content": ANY, # SAS URL changes each time
"end_turn": False,
"role": "tool",
},
Expand All @@ -143,6 +172,32 @@ def test_post_responds_successfully(app_url: str, app_config: AppConfig):
}
assert response.headers["Content-Type"] == "application/json"

content = json.loads(response.json()["choices"][0]["messages"][0]["content"])

assert content == {
"citations": [
{
"content": ANY,
"id": "doc_1",
"chunk_id": 95,
"title": "/documents/doc.png",
"filepath": "source",
"url": ANY,
"metadata": {
"offset": 202738,
"source": "https://source_SAS_TOKEN_PLACEHOLDER_",
"markdown_url": ANY,
"title": "/documents/doc.png",
"original_url": "https://source_SAS_TOKEN_PLACEHOLDER_",
"chunk": 95,
"key": "doc_1",
"filename": "source",
},
}
],
"intent": "What is the meaning of life?",
}


def test_text_passed_to_computer_vision_to_generate_text_embeddings(
app_url: str, httpserver: HTTPServer, app_config: AppConfig
Expand All @@ -169,3 +224,68 @@ def test_text_passed_to_computer_vision_to_generate_text_embeddings(
times=1,
),
)


def test_image_urls_included_in_call_to_openai(
app_url: str, app_config: AppConfig, httpserver: HTTPServer
):
# when
requests.post(f"{app_url}{path}", json=body)

# then
request = verify_request_made(
mock_httpserver=httpserver,
request_matcher=RequestMatcher(
path=f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions",
method="POST",
json={
"messages": [
{
"content": "system prompt",
"role": "system",
},
{
"content": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nuser question',
"name": "example_user",
"role": "system",
},
{
"content": "answer",
"name": "example_assistant",
"role": "system",
},
{
"content": "You are an AI assistant that helps people find information.",
"role": "system",
},
{"content": "Hello", "role": "user"},
{"content": "Hi, how can I help?", "role": "assistant"},
{
"content": [
{
"type": "text",
"text": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?',
},
{"type": "image_url", "image_url": ANY},
],
"role": "user",
},
],
"model": app_config.get("AZURE_OPENAI_VISION_MODEL"),
"max_tokens": int(app_config.get("AZURE_OPENAI_MAX_TOKENS")),
"temperature": 0,
},
headers={
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {app_config.get('AZURE_OPENAI_API_KEY')}",
"Api-Key": app_config.get("AZURE_OPENAI_API_KEY"),
},
query_string="api-version=2024-02-01",
times=1,
),
)[0]

assert request.json["messages"][6]["content"][1]["image_url"].startswith(
"https://source"
)
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,12 @@ def test_post_makes_correct_call_to_openai_chat_completions_with_documents(
{"content": "Hello", "role": "user"},
{"content": "Hi, how can I help?", "role": "assistant"},
{
"content": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?',
"content": [
{
"type": "text",
"text": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?',
}
],
"role": "user",
},
],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re

import pytest
import requests
Expand Down Expand Up @@ -90,7 +91,9 @@ def completions_mocking(httpserver: HTTPServer, app_config: AppConfig):
)

httpserver.expect_oneshot_request(
f"/openai/deployments/{app_config.get('AZURE_OPENAI_MODEL')}/chat/completions",
re.compile(
f"/openai/deployments/({app_config.get('AZURE_OPENAI_MODEL')}|{app_config.get('AZURE_OPENAI_VISION_MODEL')})/chat/completions"
),
method="POST",
).respond_with_json(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ def test_post_makes_correct_call_to_openai_chat_completions_in_question_answer_t
{"content": "Hello", "role": "user"},
{"content": "Hi, how can I help?", "role": "assistant"},
{
"content": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?',
"content": [
{
"type": "text",
"text": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?',
}
],
"role": "user",
},
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,12 @@ def test_post_makes_correct_call_to_openai_chat_completions_in_question_answer_t
{"content": "Hello", "role": "user"},
{"content": "Hi, how can I help?", "role": "assistant"},
{
"content": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?',
"content": [
{
"type": "text",
"text": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?',
}
],
"role": "user",
},
],
Expand Down
Loading

0 comments on commit 5617f82

Please sign in to comment.