From 5617f824e4565fc9b7aa41063fe17221b2df3231 Mon Sep 17 00:00:00 2001 From: Chinedum Echeta <60179183+cecheta@users.noreply.github.com> Date: Fri, 24 May 2024 11:01:04 +0100 Subject: [PATCH] feat: Include images in call to LLM (#971) Co-authored-by: Frances Tibble --- .../utilities/tools/question_answer_tool.py | 62 ++++++++- code/tests/functional/conftest.py | 32 +---- .../default/test_advanced_image_processing.py | 124 +++++++++++++++++- .../backend_api/default/test_conversation.py | 7 +- .../default/test_post_prompt_tool.py | 5 +- .../test_iv_question_answer_tool.py | 7 +- ...est_response_with_search_documents_tool.py | 7 +- .../test_advanced_image_processing.py | 30 +++++ .../tools/test_question_answer_tool.py | 86 +++++++++++- 9 files changed, 315 insertions(+), 45 deletions(-) diff --git a/code/backend/batch/utilities/tools/question_answer_tool.py b/code/backend/batch/utilities/tools/question_answer_tool.py index 269cfac7e..4485a2fdd 100644 --- a/code/backend/batch/utilities/tools/question_answer_tool.py +++ b/code/backend/batch/utilities/tools/question_answer_tool.py @@ -4,11 +4,13 @@ from ..common.answer import Answer from ..common.source_document import SourceDocument +from ..helpers.azure_blob_storage_client import AzureBlobStorageClient from ..helpers.config.config_helper import ConfigHelper from ..helpers.env_helper import EnvHelper from ..helpers.llm_helper import LLMHelper from ..search.search import Search from .answering_tool_base import AnsweringToolBase +from openai.types.chat import ChatCompletion logger = logging.getLogger(__name__) @@ -62,6 +64,7 @@ def generate_on_your_data_messages( question: str, chat_history: list[dict], sources: list[SourceDocument], + image_urls: list[str] = [], ) -> list[dict]: examples = [] @@ -122,10 +125,24 @@ def generate_on_your_data_messages( }, *QuestionAnswerTool.clean_chat_history(chat_history), { - "content": self.config.prompts.answering_user_prompt.format( - sources=documents, - question=question, - ), + "content": [ + { + "type": "text", + "text": self.config.prompts.answering_user_prompt.format( + sources=documents, + question=question, + ), + }, + *( + [ + { + "type": "image_url", + "image_url": image_url, + } + for image_url in image_urls + ] + ), + ], "role": "user", }, ] @@ -133,9 +150,16 @@ def generate_on_your_data_messages( def answer_question(self, question: str, chat_history: list[dict], **kwargs): source_documents = Search.get_source_documents(self.search_handler, question) + if self.env_helper.USE_ADVANCED_IMAGE_PROCESSING: + image_urls = self.create_image_url_list(source_documents) + else: + image_urls = [] + + model = self.env_helper.AZURE_OPENAI_VISION_MODEL if image_urls else None + if self.config.prompts.use_on_your_data_format: messages = self.generate_on_your_data_messages( - question, chat_history, source_documents + question, chat_history, source_documents, image_urls ) else: warnings.warn( @@ -145,8 +169,33 @@ def answer_question(self, question: str, chat_history: list[dict], **kwargs): llm_helper = LLMHelper() - response = llm_helper.get_chat_completion(messages, temperature=0) + response = llm_helper.get_chat_completion(messages, model=model, temperature=0) + clean_answer = self.format_answer_from_response( + response, question, source_documents + ) + return clean_answer + + def create_image_url_list(self, source_documents): + image_types = self.config.get_advanced_image_processing_image_types() + + blob_client = AzureBlobStorageClient() + container_sas = blob_client.get_container_sas() + + image_urls = [ + doc.source.replace("_SAS_TOKEN_PLACEHOLDER_", container_sas) + for doc in source_documents + if doc.title is not None and doc.title.split(".")[-1] in image_types + ] + + return image_urls + + def format_answer_from_response( + self, + response: ChatCompletion, + question: str, + source_documents: list[SourceDocument], + ): answer = response.choices[0].message.content logger.debug(f"Answer: {answer}") @@ -158,4 +207,5 @@ def answer_question(self, question: str, chat_history: list[dict], **kwargs): prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, ) + return clean_answer diff --git a/code/tests/functional/conftest.py b/code/tests/functional/conftest.py index b4a3cb18f..752ea49ac 100644 --- a/code/tests/functional/conftest.py +++ b/code/tests/functional/conftest.py @@ -1,3 +1,4 @@ +import re import pytest from pytest_httpserver import HTTPServer from tests.functional.app_config import AppConfig @@ -56,7 +57,9 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig): ) httpserver.expect_request( - f"/openai/deployments/{app_config.get('AZURE_OPENAI_MODEL')}/chat/completions", + re.compile( + f"/openai/deployments/({app_config.get('AZURE_OPENAI_MODEL')}|{app_config.get('AZURE_OPENAI_VISION_MODEL')})/chat/completions" + ), method="POST", ).respond_with_json( { @@ -82,33 +85,6 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig): } ) - httpserver.expect_request( - f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions", - method="POST", - ).respond_with_json( - { - "id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9", - "object": "chat.completion", - "created": 1679072642, - "model": app_config.get("AZURE_OPENAI_VISION_MODEL"), - "usage": { - "prompt_tokens": 58, - "completion_tokens": 68, - "total_tokens": 126, - }, - "choices": [ - { - "message": { - "role": "assistant", - "content": "This is a caption for the image", - }, - "finish_reason": "stop", - "index": 0, - } - ], - } - ) - httpserver.expect_request( f"/indexes('{app_config.get('AZURE_SEARCH_CONVERSATIONS_LOG_INDEX')}')/docs/search.index", method="POST", diff --git a/code/tests/functional/tests/backend_api/default/test_advanced_image_processing.py b/code/tests/functional/tests/backend_api/default/test_advanced_image_processing.py index bca14e580..59d71b12e 100644 --- a/code/tests/functional/tests/backend_api/default/test_advanced_image_processing.py +++ b/code/tests/functional/tests/backend_api/default/test_advanced_image_processing.py @@ -1,3 +1,6 @@ +import json +import re +from unittest.mock import ANY import pytest import requests from pytest_httpserver import HTTPServer @@ -68,7 +71,9 @@ def completions_mocking(httpserver: HTTPServer, app_config: AppConfig): ) httpserver.expect_oneshot_request( - f"/openai/deployments/{app_config.get('AZURE_OPENAI_MODEL')}/chat/completions", + re.compile( + f"/openai/deployments/({app_config.get('AZURE_OPENAI_MODEL')}|{app_config.get('AZURE_OPENAI_VISION_MODEL')})/chat/completions" + ), method="POST", ).respond_with_json( { @@ -112,6 +117,30 @@ def completions_mocking(httpserver: HTTPServer, app_config: AppConfig): } ) + httpserver.expect_oneshot_request( + f"/indexes('{app_config.get('AZURE_SEARCH_INDEX')}')/docs/search.post.search", + method="POST", + ).respond_with_json( + { + "value": [ + { + "@search.score": 0.02916666865348816, + "id": "doc_1", + "content": "content", + "content_vector": [ + -0.012909674, + 0.00838491, + ], + "metadata": '{"id": "doc_1", "source": "https://source_SAS_TOKEN_PLACEHOLDER_", "title": "/documents/doc.png", "chunk": 95, "offset": 202738, "page_number": null}', + "title": "/documents/doc.png", + "source": "https://source_SAS_TOKEN_PLACEHOLDER_", + "chunk": 95, + "offset": 202738, + } + ] + } + ) + def test_post_responds_successfully(app_url: str, app_config: AppConfig): # when @@ -124,7 +153,7 @@ def test_post_responds_successfully(app_url: str, app_config: AppConfig): { "messages": [ { - "content": r'{"citations": [{"content": "[/documents/doc.pdf](https://source)\n\n\ncontent", "id": "doc_1", "chunk_id": 95, "title": "/documents/doc.pdf", "filepath": "source", "url": "[/documents/doc.pdf](https://source)", "metadata": {"offset": 202738, "source": "https://source", "markdown_url": "[/documents/doc.pdf](https://source)", "title": "/documents/doc.pdf", "original_url": "https://source", "chunk": 95, "key": "doc_1", "filename": "source"}}], "intent": "What is the meaning of life?"}', + "content": ANY, # SAS URL changes each time "end_turn": False, "role": "tool", }, @@ -143,6 +172,32 @@ def test_post_responds_successfully(app_url: str, app_config: AppConfig): } assert response.headers["Content-Type"] == "application/json" + content = json.loads(response.json()["choices"][0]["messages"][0]["content"]) + + assert content == { + "citations": [ + { + "content": ANY, + "id": "doc_1", + "chunk_id": 95, + "title": "/documents/doc.png", + "filepath": "source", + "url": ANY, + "metadata": { + "offset": 202738, + "source": "https://source_SAS_TOKEN_PLACEHOLDER_", + "markdown_url": ANY, + "title": "/documents/doc.png", + "original_url": "https://source_SAS_TOKEN_PLACEHOLDER_", + "chunk": 95, + "key": "doc_1", + "filename": "source", + }, + } + ], + "intent": "What is the meaning of life?", + } + def test_text_passed_to_computer_vision_to_generate_text_embeddings( app_url: str, httpserver: HTTPServer, app_config: AppConfig @@ -169,3 +224,68 @@ def test_text_passed_to_computer_vision_to_generate_text_embeddings( times=1, ), ) + + +def test_image_urls_included_in_call_to_openai( + app_url: str, app_config: AppConfig, httpserver: HTTPServer +): + # when + requests.post(f"{app_url}{path}", json=body) + + # then + request = verify_request_made( + mock_httpserver=httpserver, + request_matcher=RequestMatcher( + path=f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions", + method="POST", + json={ + "messages": [ + { + "content": "system prompt", + "role": "system", + }, + { + "content": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nuser question', + "name": "example_user", + "role": "system", + }, + { + "content": "answer", + "name": "example_assistant", + "role": "system", + }, + { + "content": "You are an AI assistant that helps people find information.", + "role": "system", + }, + {"content": "Hello", "role": "user"}, + {"content": "Hi, how can I help?", "role": "assistant"}, + { + "content": [ + { + "type": "text", + "text": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?', + }, + {"type": "image_url", "image_url": ANY}, + ], + "role": "user", + }, + ], + "model": app_config.get("AZURE_OPENAI_VISION_MODEL"), + "max_tokens": int(app_config.get("AZURE_OPENAI_MAX_TOKENS")), + "temperature": 0, + }, + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {app_config.get('AZURE_OPENAI_API_KEY')}", + "Api-Key": app_config.get("AZURE_OPENAI_API_KEY"), + }, + query_string="api-version=2024-02-01", + times=1, + ), + )[0] + + assert request.json["messages"][6]["content"][1]["image_url"].startswith( + "https://source" + ) diff --git a/code/tests/functional/tests/backend_api/default/test_conversation.py b/code/tests/functional/tests/backend_api/default/test_conversation.py index 18313fa72..4d8575300 100644 --- a/code/tests/functional/tests/backend_api/default/test_conversation.py +++ b/code/tests/functional/tests/backend_api/default/test_conversation.py @@ -570,7 +570,12 @@ def test_post_makes_correct_call_to_openai_chat_completions_with_documents( {"content": "Hello", "role": "user"}, {"content": "Hi, how can I help?", "role": "assistant"}, { - "content": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?', + "content": [ + { + "type": "text", + "text": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?', + } + ], "role": "user", }, ], diff --git a/code/tests/functional/tests/backend_api/default/test_post_prompt_tool.py b/code/tests/functional/tests/backend_api/default/test_post_prompt_tool.py index dbf06b81c..a5899ea35 100644 --- a/code/tests/functional/tests/backend_api/default/test_post_prompt_tool.py +++ b/code/tests/functional/tests/backend_api/default/test_post_prompt_tool.py @@ -1,4 +1,5 @@ import json +import re import pytest import requests @@ -90,7 +91,9 @@ def completions_mocking(httpserver: HTTPServer, app_config: AppConfig): ) httpserver.expect_oneshot_request( - f"/openai/deployments/{app_config.get('AZURE_OPENAI_MODEL')}/chat/completions", + re.compile( + f"/openai/deployments/({app_config.get('AZURE_OPENAI_MODEL')}|{app_config.get('AZURE_OPENAI_VISION_MODEL')})/chat/completions" + ), method="POST", ).respond_with_json( { diff --git a/code/tests/functional/tests/backend_api/integrated_vectorization_custom_conversation/test_iv_question_answer_tool.py b/code/tests/functional/tests/backend_api/integrated_vectorization_custom_conversation/test_iv_question_answer_tool.py index 70fc8d81b..061d7a4af 100644 --- a/code/tests/functional/tests/backend_api/integrated_vectorization_custom_conversation/test_iv_question_answer_tool.py +++ b/code/tests/functional/tests/backend_api/integrated_vectorization_custom_conversation/test_iv_question_answer_tool.py @@ -228,7 +228,12 @@ def test_post_makes_correct_call_to_openai_chat_completions_in_question_answer_t {"content": "Hello", "role": "user"}, {"content": "Hi, how can I help?", "role": "assistant"}, { - "content": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?', + "content": [ + { + "type": "text", + "text": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?', + } + ], "role": "user", }, ], diff --git a/code/tests/functional/tests/backend_api/sk_orchestrator/test_response_with_search_documents_tool.py b/code/tests/functional/tests/backend_api/sk_orchestrator/test_response_with_search_documents_tool.py index 9c300192d..ba5301231 100644 --- a/code/tests/functional/tests/backend_api/sk_orchestrator/test_response_with_search_documents_tool.py +++ b/code/tests/functional/tests/backend_api/sk_orchestrator/test_response_with_search_documents_tool.py @@ -243,7 +243,12 @@ def test_post_makes_correct_call_to_openai_chat_completions_in_question_answer_t {"content": "Hello", "role": "user"}, {"content": "Hi, how can I help?", "role": "assistant"}, { - "content": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?', + "content": [ + { + "type": "text", + "text": '## Retrieved Documents\n{"retrieved_documents":[{"[doc1]":{"content":"content"}}]}\n\n## User Question\nWhat is the meaning of life?', + } + ], "role": "user", }, ], diff --git a/code/tests/functional/tests/functions/advanced_image_processing/test_advanced_image_processing.py b/code/tests/functional/tests/functions/advanced_image_processing/test_advanced_image_processing.py index f18708f08..bc146f222 100644 --- a/code/tests/functional/tests/functions/advanced_image_processing/test_advanced_image_processing.py +++ b/code/tests/functional/tests/functions/advanced_image_processing/test_advanced_image_processing.py @@ -64,6 +64,36 @@ def setup_blob_metadata_mocking(httpserver: HTTPServer, app_config: AppConfig): ).respond_with_data() +@pytest.fixture(autouse=True) +def setup_caption_response(httpserver: HTTPServer, app_config: AppConfig): + httpserver.expect_oneshot_request( + f"/openai/deployments/{app_config.get('AZURE_OPENAI_VISION_MODEL')}/chat/completions", + method="POST", + ).respond_with_json( + { + "id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9", + "object": "chat.completion", + "created": 1679072642, + "model": app_config.get("AZURE_OPENAI_VISION_MODEL"), + "usage": { + "prompt_tokens": 58, + "completion_tokens": 68, + "total_tokens": 126, + }, + "choices": [ + { + "message": { + "role": "assistant", + "content": "This is a caption for the image", + }, + "finish_reason": "stop", + "index": 0, + } + ], + } + ) + + def test_config_file_is_retrieved_from_storage( message: QueueMessage, httpserver: HTTPServer, app_config: AppConfig ): diff --git a/code/tests/utilities/tools/test_question_answer_tool.py b/code/tests/utilities/tools/test_question_answer_tool.py index 7b6de3fe3..b4411cfde 100644 --- a/code/tests/utilities/tools/test_question_answer_tool.py +++ b/code/tests/utilities/tools/test_question_answer_tool.py @@ -27,6 +27,7 @@ def config_mock(): ) config.example.user_question = "mock example user question" config.example.answer = "mock example answer" + config.get_advanced_image_processing_image_types.return_value = ["jpg", "png"] yield config @@ -39,6 +40,8 @@ def env_helper_mock(): env_helper.AZURE_SEARCH_TOP_K = 1 env_helper.AZURE_SEARCH_FILTER = "mock filter" env_helper.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION = False + env_helper.USE_ADVANCED_IMAGE_PROCESSING = False + env_helper.AZURE_OPENAI_VISION_MODEL = "mock vision model" yield env_helper @@ -58,6 +61,17 @@ def llm_helper_mock(): yield llm_helper +@pytest.fixture(autouse=True) +def azure_blob_service_mock(): + with patch( + "backend.batch.utilities.tools.question_answer_tool.AzureBlobStorageClient" + ) as mock: + blob_helper = mock.return_value + blob_helper.get_container_sas.return_value = "mock sas" + + yield blob_helper + + @pytest.fixture(autouse=True) def search_handler_mock(): with patch( @@ -86,8 +100,8 @@ def source_documents_mock(): SourceDocument( id="mock id 2", content="mock content 2", - title="mock title 2", - source="mock source 2", + title="mock title 2.jpg", + source="mock source 2_SAS_TOKEN_PLACEHOLDER_", chunk_id="mock chunk id 2", ), ] @@ -159,10 +173,16 @@ def test_correct_prompt_with_few_shot_example(llm_helper_mock: MagicMock): }, {"content": "mock azure openai system message", "role": "system"}, { - "content": 'Sources: {"retrieved_documents":[{"[doc1]":{"content":"mock content"}},{"[doc2]":{"content":"mock content 2"}}]}, Question: mock question', + "content": [ + { + "type": "text", + "text": 'Sources: {"retrieved_documents":[{"[doc1]":{"content":"mock content"}},{"[doc2]":{"content":"mock content 2"}}]}, Question: mock question', + } + ], "role": "user", }, ], + model=None, temperature=0, ) @@ -187,10 +207,16 @@ def test_correct_prompt_without_few_shot_example( {"content": "mock answering system prompt", "role": "system"}, {"content": "mock azure openai system message", "role": "system"}, { - "content": 'Sources: {"retrieved_documents":[{"[doc1]":{"content":"mock content"}},{"[doc2]":{"content":"mock content 2"}}]}, Question: mock question', + "content": [ + { + "type": "text", + "text": 'Sources: {"retrieved_documents":[{"[doc1]":{"content":"mock content"}},{"[doc2]":{"content":"mock content 2"}}]}, Question: mock question', + } + ], "role": "user", }, ], + model=None, temperature=0, ) @@ -226,10 +252,16 @@ def test_correct_prompt_with_few_shot_example_and_chat_history( {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi, how can I help?"}, { - "content": 'Sources: {"retrieved_documents":[{"[doc1]":{"content":"mock content"}},{"[doc2]":{"content":"mock content 2"}}]}, Question: mock question', + "content": [ + { + "type": "text", + "text": 'Sources: {"retrieved_documents":[{"[doc1]":{"content":"mock content"}},{"[doc2]":{"content":"mock content 2"}}]}, Question: mock question', + } + ], "role": "user", }, ], + model=None, temperature=0, ) @@ -260,6 +292,7 @@ def test_non_on_your_data_prompt_correct( "role": "user", }, ], + model=None, temperature=0, ) @@ -274,3 +307,46 @@ def test_json_remove_whitespace(input: str, expected: str): # then assert result == expected + + +def test_use_advanced_vision_processing(env_helper_mock, llm_helper_mock): + # given + env_helper_mock.USE_ADVANCED_IMAGE_PROCESSING = True + tool = QuestionAnswerTool() + + # when + answer = tool.answer_question("mock question", []) + + # then + llm_helper_mock.get_chat_completion.assert_called_once_with( + [ + {"content": "mock answering system prompt", "role": "system"}, + { + "content": 'Sources: {"retrieved_documents":[{"[doc1]":{"content":"mock example content"}}]}, Question: mock example user question', + "name": "example_user", + "role": "system", + }, + { + "content": "mock example answer", + "name": "example_assistant", + "role": "system", + }, + {"content": "mock azure openai system message", "role": "system"}, + { + "content": [ + { + "type": "text", + "text": 'Sources: {"retrieved_documents":[{"[doc1]":{"content":"mock content"}},{"[doc2]":{"content":"mock content 2"}}]}, Question: mock question', + }, + {"type": "image_url", "image_url": "mock source 2mock sas"}, + ], + "role": "user", + }, + ], + model="mock vision model", + temperature=0, + ) + + assert isinstance(answer, Answer) + assert answer.question == "mock question" + assert answer.answer == "mock content"