diff --git a/doc/_toc.yml b/doc/_toc.yml index ff154b0de..db4a48e78 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -57,6 +57,7 @@ chapters: - file: code/targets/5_multi_modal_targets - file: code/targets/6_rate_limiting - file: code/targets/7_http_target + - file: code/targets/groq_chat_target - file: code/targets/open_ai_completions - file: code/targets/playwright_target - file: code/targets/prompt_shield_target diff --git a/doc/code/targets/groq_chat_target.ipynb b/doc/code/targets/groq_chat_target.ipynb new file mode 100644 index 000000000..2c25c681d --- /dev/null +++ b/doc/code/targets/groq_chat_target.ipynb @@ -0,0 +1,114 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# GroqChatTarget\n", + "\n", + "This example demonstrates how to use the `GroqChatTarget` class in PyRIT to send a prompt\n", + "to a Groq model and retrieve a response.\n", + "\n", + "## Setup\n", + "Before running this example, you need to set the following environment variables:\n", + "\n", + "```\n", + "export GROQ_API_KEY=\"your_api_key_here\"\n", + "export GROQ_MODEL_NAME=\"llama3-8b-8192\"\n", + "```\n", + "\n", + "Alternatively, you can pass these values as arguments when initializing `GroqChatTarget`:\n", + "\n", + "```python\n", + "groq_target = GroqChatTarget(model_name=\"llama3-8b-8192\", api_key=\"your_api_key_here\")\n", + "```\n", + "\n", + "You can also limit the request rate using `max_requests_per_minute`.\n", + "\n", + "## Example\n", + "The following code initializes `GroqChatTarget`, sends a prompt using `PromptSendingOrchestrator`,\n", + "and retrieves a response." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[22m\u001b[39mConversation ID: 7ae4ae98-a23b-4330-9c3e-5fd9e8c37854\n", + "\u001b[1m\u001b[34muser: Why is the sky blue ?\n", + "\u001b[22m\u001b[33massistant: The sky appears blue because of a phenomenon called Rayleigh scattering, which is the scattering of light by small particles or molecules in the atmosphere.\n", + "\n", + "When sunlight enters Earth's atmosphere, it encounters tiny molecules of gases such as nitrogen (N2) and oxygen (O2). These molecules scatter the light in all directions, but they scatter shorter (blue) wavelengths more than longer (red) wavelengths. This is known as Rayleigh scattering.\n", + "\n", + "As a result of this scattering, the blue light is dispersed throughout the atmosphere, reaching our eyes from all directions. This is why the sky appears blue during the daytime, as the blue light is being scattered in all directions and reaching our eyes from all parts of the sky.\n", + "\n", + "In addition to Rayleigh scattering, there are other factors that can affect the color of the sky, such as:\n", + "\n", + "* Mie scattering: This is the scattering of light by larger particles, such as dust, pollen, and water droplets. Mie scattering can give the sky a more orange or pinkish hue during sunrise and sunset.\n", + "* Scattering by cloud droplets: Clouds can scatter light in a way that gives the sky a more white or gray appearance.\n", + "* Atmospheric conditions: Factors such as pollution, dust, and water vapor can also affect the color of the sky, making it appear more hazy or brownish.\n", + "\n", + "Overall, the combination of Rayleigh scattering and other atmospheric effects is what gives the sky its blue color during the daytime.\n" + ] + } + ], + "source": [ + "\n", + "from pyrit.common import IN_MEMORY, initialize_pyrit\n", + "from pyrit.orchestrator import PromptSendingOrchestrator\n", + "from pyrit.prompt_target import GroqChatTarget\n", + "\n", + "initialize_pyrit(memory_db_type=IN_MEMORY)\n", + "\n", + "groq_target = GroqChatTarget()\n", + "\n", + "prompt = \"Why is the sky blue ?\"\n", + "\n", + "orchestrator = PromptSendingOrchestrator(objective_target=groq_target)\n", + "\n", + "response = await orchestrator.send_prompts_async(prompt_list=[prompt]) # type: ignore\n", + "await orchestrator.print_conversations_async() # type: ignore" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "pyrt_env", + "language": "python", + "name": "pyrt_env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/code/targets/groq_chat_target.py b/doc/code/targets/groq_chat_target.py new file mode 100644 index 000000000..3c8746122 --- /dev/null +++ b/doc/code/targets/groq_chat_target.py @@ -0,0 +1,56 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.7 +# kernelspec: +# display_name: pyrt_env +# language: python +# name: pyrt_env +# --- + +# %% [markdown] +# # GroqChatTarget +# +# This example demonstrates how to use the `GroqChatTarget` class in PyRIT to send a prompt +# to a Groq model and retrieve a response. +# +# ## Setup +# Before running this example, you need to set the following environment variables: +# +# ``` +# export GROQ_API_KEY="your_api_key_here" +# export GROQ_MODEL_NAME="llama3-8b-8192" +# ``` +# +# Alternatively, you can pass these values as arguments when initializing `GroqChatTarget`: +# +# ```python +# groq_target = GroqChatTarget(model_name="llama3-8b-8192", api_key="your_api_key_here") +# ``` +# +# You can also limit the request rate using `max_requests_per_minute`. +# +# ## Example +# The following code initializes `GroqChatTarget`, sends a prompt using `PromptSendingOrchestrator`, +# and retrieves a response. +# %% + +from pyrit.common import IN_MEMORY, initialize_pyrit +from pyrit.orchestrator import PromptSendingOrchestrator +from pyrit.prompt_target import GroqChatTarget + +initialize_pyrit(memory_db_type=IN_MEMORY) + +groq_target = GroqChatTarget() + +prompt = "Why is the sky blue ?" + +orchestrator = PromptSendingOrchestrator(objective_target=groq_target) + +response = await orchestrator.send_prompts_async(prompt_list=[prompt]) # type: ignore +await orchestrator.print_conversations_async() # type: ignore diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 6c7742d18..758716275 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -12,6 +12,7 @@ from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget from pyrit.prompt_target.crucible_target import CrucibleTarget from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget +from pyrit.prompt_target.groq_chat_target import GroqChatTarget from pyrit.prompt_target.http_target.http_target import HTTPTarget from pyrit.prompt_target.http_target.http_target_callback_functions import ( get_http_target_json_response_callback_function, @@ -34,6 +35,7 @@ "CrucibleTarget", "GandalfLevel", "GandalfTarget", + "GroqChatTarget", "get_http_target_json_response_callback_function", "get_http_target_regex_matching_callback_function", "HTTPTarget", diff --git a/pyrit/prompt_target/groq_chat_target.py b/pyrit/prompt_target/groq_chat_target.py new file mode 100644 index 000000000..accc21bc2 --- /dev/null +++ b/pyrit/prompt_target/groq_chat_target.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging + +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion + +from pyrit.common import default_values +from pyrit.exceptions import EmptyResponseException, PyritException, pyrit_target_retry +from pyrit.models import ChatMessageListDictContent +from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget + +logger = logging.getLogger(__name__) + + +class GroqChatTarget(OpenAIChatTarget): + """ + A chat target for interacting with Groq's OpenAI-compatible API. + + This class extends `OpenAIChatTarget` and ensures compatibility with Groq's API, + which requires `msg.content` to be a string instead of a list of dictionaries. + + Attributes: + API_KEY_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq API key. + MODEL_NAME_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq model name. + GROQ_API_BASE_URL (str): The fixed API base URL for Groq. + """ + + API_KEY_ENVIRONMENT_VARIABLE = "GROQ_API_KEY" + MODEL_NAME_ENVIRONMENT_VARIABLE = "GROQ_MODEL_NAME" + GROQ_API_BASE_URL = "https://api.groq.com/openai/v1/" + + def __init__(self, *, model_name: str = None, api_key: str = None, max_requests_per_minute: int = None, **kwargs): + """ + Initializes GroqChatTarget with the correct API settings. + + Args: + model_name (str, optional): The model to use. Defaults to `GROQ_MODEL_NAME` env variable. + api_key (str, optional): The API key for authentication. Defaults to `GROQ_API_KEY` env variable. + max_requests_per_minute (int, optional): Rate limit for requests. + """ + + kwargs.pop("endpoint", None) + kwargs.pop("deployment_name", None) + + super().__init__( + deployment_name=model_name, + endpoint=self.GROQ_API_BASE_URL, + api_key=api_key, + is_azure_target=False, + max_requests_per_minute=max_requests_per_minute, + **kwargs, + ) + + def _initialize_non_azure_vars(self, deployment_name: str, endpoint: str, api_key: str): + """ + Initializes variables to communicate with the (non-Azure) OpenAI API, in this case Groq. + + Args: + deployment_name (str): The model name. + endpoint (str): The API base URL. + api_key (str): The API key. + + Raises: + ValueError: If _deployment_name or _api_key is missing. + """ + self._api_key = default_values.get_required_value( + env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key + ) + if not self._api_key: + raise ValueError("API key for Groq is missing. Ensure GROQ_API_KEY is set in the environment.") + + self._deployment_name = default_values.get_required_value( + env_var_name=self.MODEL_NAME_ENVIRONMENT_VARIABLE, passed_value=deployment_name + ) + if not self._deployment_name: + raise ValueError("Model name for Groq is missing. Ensure GROQ_MODEL_NAME is set in the environment.") + + # Ignoring mypy type error. The OpenAI client and Azure OpenAI client have the same private base class + self._async_client = AsyncOpenAI( # type: ignore + api_key=self._api_key, default_headers=self._extra_headers, base_url=endpoint + ) + + @pyrit_target_retry + async def _complete_chat_async(self, messages: list[ChatMessageListDictContent], is_json_response: bool) -> str: + """ + Completes asynchronous chat request. + + Sends a chat message to the OpenAI chat model and retrieves the generated response. + This method modifies the request structure to ensure compatibility with Groq, + which requires `msg.content` as a string instead of a list of dictionaries. + msg.content -> msg.content[0].get("text") + + Args: + messages (list[ChatMessageListDictContent]): The chat message objects containing the role and content. + is_json_response (bool): Boolean indicating if the response should be in JSON format. + + Returns: + str: The generated response message. + """ + response: ChatCompletion = await self._async_client.chat.completions.create( + model=self._deployment_name, + max_completion_tokens=self._max_completion_tokens, + max_tokens=self._max_tokens, + temperature=self._temperature, + top_p=self._top_p, + frequency_penalty=self._frequency_penalty, + presence_penalty=self._presence_penalty, + n=1, + stream=False, + seed=self._seed, + messages=[{"role": msg.role, "content": msg.content[0].get("text")} for msg in messages], # type: ignore + response_format={"type": "json_object"} if is_json_response else None, + ) + finish_reason = response.choices[0].finish_reason + extracted_response: str = "" + # finish_reason="stop" means API returned complete message and + # "length" means API returned incomplete message due to max_tokens limit. + if finish_reason in ["stop", "length"]: + extracted_response = self._parse_chat_completion(response) + # Handle empty response + if not extracted_response: + logger.log(logging.ERROR, "The chat returned an empty response.") + raise EmptyResponseException(message="The chat returned an empty response.") + else: + raise PyritException(message=f"Unknown finish_reason {finish_reason}") + + return extracted_response diff --git a/tests/unit/target/test_groq_chat_target.py b/tests/unit/target/test_groq_chat_target.py new file mode 100644 index 000000000..53c6ed568 --- /dev/null +++ b/tests/unit/target/test_groq_chat_target.py @@ -0,0 +1,501 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from contextlib import AbstractAsyncContextManager +from tempfile import NamedTemporaryFile +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from openai import BadRequestError, RateLimitError +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from unit.mocks import get_image_request_piece + +from pyrit.exceptions.exception_classes import EmptyResponseException +from pyrit.memory.duckdb_memory import DuckDBMemory +from pyrit.memory.memory_interface import MemoryInterface +from pyrit.models import ( + ChatMessageListDictContent, + PromptRequestPiece, + PromptRequestResponse, +) +from pyrit.prompt_target import GroqChatTarget + + +@pytest.fixture +def groq_chat_engine(patch_central_database) -> GroqChatTarget: + return GroqChatTarget( + model_name="llama3-8b-8192", + api_key="mock-api-key", + api_version="some_version", + ) + + +@pytest.fixture +def groq_mock_return() -> ChatCompletion: + return ChatCompletion( + id="12345678-1a2b-3c4e5f-a123-12345678abcd", + object="chat.completion", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage(role="assistant", content="hi"), + finish_reason="stop", + logprobs=None, + ) + ], + created=1629389505, + model="llama3-8b-8192", + ) + + +class MockChatCompletionsAsync(AbstractAsyncContextManager): + async def __call__(self, *args, **kwargs): + self.mock_chat_completion = ChatCompletion( + id="12345678-1a2b-3c4e5f-a123-12345678abcd", + object="chat.completion", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage(role="assistant", content="hi"), + finish_reason="stop", + logprobs=None, + ) + ], + created=1629389505, + model="llama3-8b-8192", + ) + return self.mock_chat_completion + + async def __aexit__(self, exc_type, exc, tb): + pass + + async def __aenter__(self): + pass + + +@patch( + "openai.resources.chat.AsyncCompletions.create", + new_callable=lambda: MockChatCompletionsAsync(), +) +@pytest.mark.asyncio +async def test_complete_chat_async_return(groq_mock_return: ChatCompletion, groq_chat_engine: GroqChatTarget): + with patch("openai.resources.chat.Completions.create") as mock_create: + mock_create.return_value = groq_mock_return + ret = await groq_chat_engine._complete_chat_async( + messages=[ChatMessageListDictContent(role="user", content=[{"text": "hello"}])], is_json_response=False + ) + assert ret == "hi" + + +def test_init_with_no_env_var_raises(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError): + GroqChatTarget( + model_name="llama3-8b-8192", + api_key="", + api_version="some_version", + ) + + +def test_init_with_no_deployment_var_raises(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError): + GroqChatTarget() + + +@pytest.mark.asyncio() +async def test_convert_image_to_data_url_file_not_found(groq_chat_engine: GroqChatTarget): + with pytest.raises(FileNotFoundError): + await groq_chat_engine._convert_local_image_to_data_url("nonexistent.jpg") + + +@pytest.mark.asyncio() +async def test_convert_image_with_unsupported_extension(groq_chat_engine: GroqChatTarget): + + with NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file: + tmp_file_name = tmp_file.name + + assert os.path.exists(tmp_file_name) + + with pytest.raises(ValueError) as exc_info: + await groq_chat_engine._convert_local_image_to_data_url(tmp_file_name) + + assert "Unsupported image format" in str(exc_info.value) + + os.remove(tmp_file_name) + + +@pytest.mark.asyncio() +@patch("os.path.exists", return_value=True) +@patch("mimetypes.guess_type", return_value=("image/jpg", None)) +@patch("pyrit.models.data_type_serializer.ImagePathDataTypeSerializer") +@patch("pyrit.memory.CentralMemory.get_memory_instance", return_value=DuckDBMemory(db_path=":memory:")) +async def test_convert_image_to_data_url_success( + mock_get_memory_instance, mock_serializer_class, mock_guess_type, mock_exists, groq_chat_engine: GroqChatTarget +): + with NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: + tmp_file_name = tmp_file.name + mock_serializer_instance = MagicMock() + mock_serializer_instance.read_data_base64 = AsyncMock(return_value="encoded_base64_string") + mock_serializer_class.return_value = mock_serializer_instance + + assert os.path.exists(tmp_file_name) + + result = await groq_chat_engine._convert_local_image_to_data_url(tmp_file_name) + assert "_base64_string" in result + + # Assertions for the mocks + mock_serializer_class.assert_called_once_with( + category="prompt-memory-entries", prompt_text=tmp_file_name, extension=".jpg" + ) + mock_serializer_instance.read_data_base64.assert_called_once() + + os.remove(tmp_file_name) + + +@pytest.mark.asyncio() +async def test_build_chat_messages_with_consistent_roles(groq_chat_engine: GroqChatTarget): + + image_request = get_image_request_piece() + entries = [ + PromptRequestResponse( + request_pieces=[ + PromptRequestPiece( + role="user", + converted_value_data_type="text", + original_value="Hello", + ), + image_request, + ] + ) + ] + with patch.object( + groq_chat_engine, + "_convert_local_image_to_data_url", + return_value="_string", + ): + messages = await groq_chat_engine._build_chat_messages(entries) + + assert len(messages) == 1 + assert messages[0].role == "user" + assert messages[0].content[0]["type"] == "text" # type: ignore + assert messages[0].content[1]["type"] == "image_url" # type: ignore + + os.remove(image_request.original_value) + + +@pytest.mark.asyncio +async def test_build_chat_messages_with_unsupported_data_types(groq_chat_engine: GroqChatTarget): + # Like an image_path, the audio_path requires a file, but doesn't validate any contents + entry = get_image_request_piece() + entry.converted_value_data_type = "audio_path" + + with pytest.raises(ValueError) as excinfo: + await groq_chat_engine._build_chat_messages([PromptRequestResponse(request_pieces=[entry])]) + assert "Multimodal data type audio_path is not yet supported." in str(excinfo.value) + + +@pytest.mark.asyncio +async def test_send_prompt_async_empty_response_adds_to_memory( + groq_mock_return: ChatCompletion, groq_chat_engine: GroqChatTarget +): + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + + groq_chat_engine._memory = mock_memory + + with NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: + tmp_file_name = tmp_file.name + assert os.path.exists(tmp_file_name) + prompt_req_resp = PromptRequestResponse( + request_pieces=[ + PromptRequestPiece( + role="user", + conversation_id="12345679", + original_value="hello", + converted_value="hello", + original_value_data_type="text", + converted_value_data_type="text", + prompt_target_identifier={"target": "target-identifier"}, + orchestrator_identifier={"test": "test"}, + labels={"test": "test"}, + ), + PromptRequestPiece( + role="user", + conversation_id="12345679", + original_value=tmp_file_name, + converted_value=tmp_file_name, + original_value_data_type="image_path", + converted_value_data_type="image_path", + prompt_target_identifier={"target": "target-identifier"}, + orchestrator_identifier={"test": "test"}, + labels={"test": "test"}, + ), + ] + ) + # Make assistant response empty + groq_mock_return.choices[0].message.content = "" + with patch.object( + groq_chat_engine, + "_convert_local_image_to_data_url", + return_value="_string", + ): + with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create: + mock_create.return_value = groq_mock_return + with pytest.raises(EmptyResponseException) as e: + await groq_chat_engine.send_prompt_async(prompt_request=prompt_req_resp) + groq_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="12345679") + groq_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_req_resp) + assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response." + + +@pytest.mark.asyncio +async def test_send_prompt_async_rate_limit_exception_adds_to_memory( + groq_chat_engine: GroqChatTarget, +): + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + + groq_chat_engine._memory = mock_memory + + response = MagicMock() + response.status_code = 429 + mock_complete_chat_async = AsyncMock( + side_effect=RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached") + ) + setattr(groq_chat_engine, "_complete_chat_async", mock_complete_chat_async) + prompt_request = PromptRequestResponse( + request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")] + ) + + with pytest.raises(RateLimitError) as rle: + await groq_chat_engine.send_prompt_async(prompt_request=prompt_request) + groq_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") + groq_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) + + assert str(rle.value) == "Rate Limit Reached" + + +@pytest.mark.asyncio +async def test_send_prompt_async_bad_request_error_adds_to_memory(groq_chat_engine: GroqChatTarget): + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + + groq_chat_engine._memory = mock_memory + + response = MagicMock() + response.status_code = 400 + mock_complete_chat_async = AsyncMock( + side_effect=BadRequestError("Bad Request", response=response, body="Bad Request") + ) + setattr(groq_chat_engine, "_complete_chat_async", mock_complete_chat_async) + prompt_request = PromptRequestResponse( + request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")] + ) + + with pytest.raises(BadRequestError) as bre: + await groq_chat_engine.send_prompt_async(prompt_request=prompt_request) + groq_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") + groq_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) + + assert str(bre.value) == "Bad Request" + + +@pytest.mark.asyncio +async def test_send_prompt_async(groq_mock_return: ChatCompletion, groq_chat_engine: GroqChatTarget): + with NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: + tmp_file_name = tmp_file.name + assert os.path.exists(tmp_file_name) + prompt_req_resp = PromptRequestResponse( + request_pieces=[ + PromptRequestPiece( + role="user", + conversation_id="12345679", + original_value="hello", + converted_value="hello", + original_value_data_type="text", + converted_value_data_type="text", + prompt_target_identifier={"target": "target-identifier"}, + orchestrator_identifier={"test": "test"}, + labels={"test": "test"}, + ), + PromptRequestPiece( + role="user", + conversation_id="12345679", + original_value=tmp_file_name, + converted_value=tmp_file_name, + original_value_data_type="image_path", + converted_value_data_type="image_path", + prompt_target_identifier={"target": "target-identifier"}, + orchestrator_identifier={"test": "test"}, + labels={"test": "test"}, + ), + ] + ) + with patch.object( + groq_chat_engine, + "_convert_local_image_to_data_url", + return_value="_string", + ): + with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create: + mock_create.return_value = groq_mock_return + response: PromptRequestResponse = await groq_chat_engine.send_prompt_async(prompt_request=prompt_req_resp) + assert len(response.request_pieces) == 1 + assert response.request_pieces[0].converted_value == "hi" + os.remove(tmp_file_name) + + +@pytest.mark.asyncio +async def test_send_prompt_async_empty_response_retries( + groq_mock_return: ChatCompletion, groq_chat_engine: GroqChatTarget +): + with NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: + tmp_file_name = tmp_file.name + assert os.path.exists(tmp_file_name) + prompt_req_resp = PromptRequestResponse( + request_pieces=[ + PromptRequestPiece( + role="user", + conversation_id="12345679", + original_value="hello", + converted_value="hello", + original_value_data_type="text", + converted_value_data_type="text", + prompt_target_identifier={"target": "target-identifier"}, + orchestrator_identifier={"test": "test"}, + labels={"test": "test"}, + ), + PromptRequestPiece( + role="user", + conversation_id="12345679", + original_value=tmp_file_name, + converted_value=tmp_file_name, + original_value_data_type="image_path", + converted_value_data_type="image_path", + prompt_target_identifier={"target": "target-identifier"}, + orchestrator_identifier={"test": "test"}, + labels={"test": "test"}, + ), + ] + ) + # Make assistant response empty + groq_mock_return.choices[0].message.content = "" + with patch.object( + groq_chat_engine, + "_convert_local_image_to_data_url", + return_value="_string", + ): + with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create: + mock_create.return_value = groq_mock_return + groq_chat_engine._memory = MagicMock(MemoryInterface) + + with pytest.raises(EmptyResponseException): + await groq_chat_engine.send_prompt_async(prompt_request=prompt_req_resp) + + assert mock_create.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) + + +@pytest.mark.asyncio +async def test_send_prompt_async_rate_limit_exception_retries(groq_chat_engine: GroqChatTarget): + + response = MagicMock() + response.status_code = 429 + mock_complete_chat_async = AsyncMock( + side_effect=RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached") + ) + setattr(groq_chat_engine, "_complete_chat_async", mock_complete_chat_async) + prompt_request = PromptRequestResponse( + request_pieces=[PromptRequestPiece(role="user", conversation_id="12345", original_value="Hello")] + ) + + with pytest.raises(RateLimitError): + await groq_chat_engine.send_prompt_async(prompt_request=prompt_request) + assert mock_complete_chat_async.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS") + + +@pytest.mark.asyncio +async def test_send_prompt_async_bad_request_error(groq_chat_engine: GroqChatTarget): + + response = MagicMock() + response.status_code = 400 + mock_complete_chat_async = AsyncMock( + side_effect=BadRequestError("Bad Request Error", response=response, body="Bad request") + ) + setattr(groq_chat_engine, "_complete_chat_async", mock_complete_chat_async) + + prompt_request = PromptRequestResponse( + request_pieces=[PromptRequestPiece(role="user", conversation_id="1236748", original_value="Hello")] + ) + with pytest.raises(BadRequestError) as bre: + await groq_chat_engine.send_prompt_async(prompt_request=prompt_request) + assert str(bre.value) == "Bad Request Error" + + +def test_parse_chat_completion_successful(groq_chat_engine: GroqChatTarget): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message = MagicMock() + mock_response.choices[0].message.content = "Test response message" + result = groq_chat_engine._parse_chat_completion(mock_response) + assert result == "Test response message", "The response message was not parsed correctly" + + +def test_validate_request_unsupported_data_types(groq_chat_engine: GroqChatTarget): + + image_piece = get_image_request_piece() + image_piece.converted_value_data_type = "new_unknown_type" # type: ignore + prompt_request = PromptRequestResponse( + request_pieces=[ + PromptRequestPiece(role="user", original_value="Hello", converted_value_data_type="text"), + image_piece, + ] + ) + + with pytest.raises(ValueError) as excinfo: + groq_chat_engine._validate_request(prompt_request=prompt_request) + + assert "This target only supports text and image_path." in str( + excinfo.value + ), "Error not raised for unsupported data types" + + os.remove(image_piece.original_value) + + +def test_is_json_response_supported(groq_chat_engine: GroqChatTarget): + assert groq_chat_engine.is_json_response_supported() is True + + +def test_is_response_format_json_supported(groq_chat_engine: GroqChatTarget): + + request_piece = PromptRequestPiece( + role="user", + original_value="original prompt text", + converted_value="Hello, how are you?", + conversation_id="conversation_1", + sequence=0, + prompt_metadata={"response_format": "json"}, + ) + + result = groq_chat_engine.is_response_format_json(request_piece) + + assert result is True + + +def test_is_response_format_json_no_metadata(groq_chat_engine: GroqChatTarget): + request_piece = PromptRequestPiece( + role="user", + original_value="original prompt text", + converted_value="Hello, how are you?", + conversation_id="conversation_1", + sequence=0, + prompt_metadata=None, + ) + + result = groq_chat_engine.is_response_format_json(request_piece) + + assert result is False