FEAT: Add GroqChatTarget (#704) #705

Merged
merged 4 commits on Feb 19, 2025
1 change: 1 addition & 0 deletions doc/_toc.yml
@@ -57,6 +57,7 @@ chapters:
- file: code/targets/5_multi_modal_targets
- file: code/targets/6_rate_limiting
- file: code/targets/7_http_target
- file: code/targets/groq_chat_target
- file: code/targets/open_ai_completions
- file: code/targets/playwright_target
- file: code/targets/prompt_shield_target
114 changes: 114 additions & 0 deletions doc/code/targets/groq_chat_target.ipynb
@@ -0,0 +1,114 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"# GroqChatTarget\n",
"\n",
"This example demonstrates how to use the `GroqChatTarget` class in PyRIT to send a prompt\n",
"to a Groq model and retrieve a response.\n",
"\n",
"## Setup\n",
"Before running this example, you need to set the following environment variables:\n",
"\n",
"```\n",
"export GROQ_API_KEY=\"your_api_key_here\"\n",
"export GROQ_MODEL_NAME=\"llama3-8b-8192\"\n",
"```\n",
"\n",
"Alternatively, you can pass these values as arguments when initializing `GroqChatTarget`:\n",
"\n",
"```python\n",
"groq_target = GroqChatTarget(model_name=\"llama3-8b-8192\", api_key=\"your_api_key_here\")\n",
"```\n",
"\n",
"You can also limit the request rate using `max_requests_per_minute`.\n",
"\n",
"## Example\n",
"The following code initializes `GroqChatTarget`, sends a prompt using `PromptSendingOrchestrator`,\n",
"and retrieves a response."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[39mConversation ID: 7ae4ae98-a23b-4330-9c3e-5fd9e8c37854\n",
"\u001b[1m\u001b[34muser: Why is the sky blue ?\n",
"\u001b[22m\u001b[33massistant: The sky appears blue because of a phenomenon called Rayleigh scattering, which is the scattering of light by small particles or molecules in the atmosphere.\n",
"\n",
"When sunlight enters Earth's atmosphere, it encounters tiny molecules of gases such as nitrogen (N2) and oxygen (O2). These molecules scatter the light in all directions, but they scatter shorter (blue) wavelengths more than longer (red) wavelengths. This is known as Rayleigh scattering.\n",
"\n",
"As a result of this scattering, the blue light is dispersed throughout the atmosphere, reaching our eyes from all directions. This is why the sky appears blue during the daytime, as the blue light is being scattered in all directions and reaching our eyes from all parts of the sky.\n",
"\n",
"In addition to Rayleigh scattering, there are other factors that can affect the color of the sky, such as:\n",
"\n",
"* Mie scattering: This is the scattering of light by larger particles, such as dust, pollen, and water droplets. Mie scattering can give the sky a more orange or pinkish hue during sunrise and sunset.\n",
"* Scattering by cloud droplets: Clouds can scatter light in a way that gives the sky a more white or gray appearance.\n",
"* Atmospheric conditions: Factors such as pollution, dust, and water vapor can also affect the color of the sky, making it appear more hazy or brownish.\n",
"\n",
"Overall, the combination of Rayleigh scattering and other atmospheric effects is what gives the sky its blue color during the daytime.\n"
]
}
],
"source": [
"\n",
"from pyrit.common import IN_MEMORY, initialize_pyrit\n",
"from pyrit.orchestrator import PromptSendingOrchestrator\n",
"from pyrit.prompt_target import GroqChatTarget\n",
"\n",
"initialize_pyrit(memory_db_type=IN_MEMORY)\n",
"\n",
"groq_target = GroqChatTarget()\n",
"\n",
"prompt = \"Why is the sky blue ?\"\n",
"\n",
"orchestrator = PromptSendingOrchestrator(objective_target=groq_target)\n",
"\n",
"response = await orchestrator.send_prompts_async(prompt_list=[prompt]) # type: ignore\n",
"await orchestrator.print_conversations_async() # type: ignore"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "pyrt_env",
"language": "python",
"name": "pyrt_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
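The markdown cell above notes that the model name, API key, and rate limit can be passed directly to the constructor instead of being read from the `GROQ_MODEL_NAME` and `GROQ_API_KEY` environment variables. A minimal sketch of that explicit-argument variant, using only the constructor parameters the cell describes (the key and model name below are placeholders, and the rate-limit value is illustrative):

```python
from pyrit.common import IN_MEMORY, initialize_pyrit
from pyrit.orchestrator import PromptSendingOrchestrator
from pyrit.prompt_target import GroqChatTarget

initialize_pyrit(memory_db_type=IN_MEMORY)

# Placeholder credentials: pass model_name/api_key explicitly instead of
# relying on the GROQ_MODEL_NAME and GROQ_API_KEY environment variables.
groq_target = GroqChatTarget(
    model_name="llama3-8b-8192",
    api_key="your_api_key_here",
    max_requests_per_minute=10,  # optional client-side rate limit
)

orchestrator = PromptSendingOrchestrator(objective_target=groq_target)

await orchestrator.send_prompts_async(prompt_list=["Why is the sky blue?"])  # type: ignore
await orchestrator.print_conversations_async()  # type: ignore
```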
56 changes: 56 additions & 0 deletions doc/code/targets/groq_chat_target.py
@@ -0,0 +1,56 @@
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.7
# kernelspec:
# display_name: pyrt_env
# language: python
# name: pyrt_env
# ---

# %% [markdown]
# # GroqChatTarget
#
# This example demonstrates how to use the `GroqChatTarget` class in PyRIT to send a prompt
# to a Groq model and retrieve a response.
#
# ## Setup
# Before running this example, you need to set the following environment variables:
#
# ```
# export GROQ_API_KEY="your_api_key_here"
# export GROQ_MODEL_NAME="llama3-8b-8192"
# ```
#
# Alternatively, you can pass these values as arguments when initializing `GroqChatTarget`:
#
# ```python
# groq_target = GroqChatTarget(model_name="llama3-8b-8192", api_key="your_api_key_here")
# ```
#
# You can also limit the request rate using `max_requests_per_minute`.
#
# ## Example
# The following code initializes `GroqChatTarget`, sends a prompt using `PromptSendingOrchestrator`,
# and retrieves a response.
# %%

from pyrit.common import IN_MEMORY, initialize_pyrit
from pyrit.orchestrator import PromptSendingOrchestrator
from pyrit.prompt_target import GroqChatTarget

initialize_pyrit(memory_db_type=IN_MEMORY)

groq_target = GroqChatTarget()

prompt = "Why is the sky blue ?"

orchestrator = PromptSendingOrchestrator(objective_target=groq_target)

response = await orchestrator.send_prompts_async(prompt_list=[prompt]) # type: ignore
await orchestrator.print_conversations_async() # type: ignore
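The jupytext source above uses top-level `await`, which runs as-is in a Jupyter kernel. For running the same flow as a standalone script, one plausible wrapper is sketched below; the `asyncio.run` entry point is an assumption for illustration, not part of this PR.

```python
import asyncio

from pyrit.common import IN_MEMORY, initialize_pyrit
from pyrit.orchestrator import PromptSendingOrchestrator
from pyrit.prompt_target import GroqChatTarget


async def main() -> None:
    # Same calls as the jupytext example above, wrapped in a coroutine
    # so the script can be executed by a regular Python interpreter.
    initialize_pyrit(memory_db_type=IN_MEMORY)
    groq_target = GroqChatTarget()
    orchestrator = PromptSendingOrchestrator(objective_target=groq_target)
    await orchestrator.send_prompts_async(prompt_list=["Why is the sky blue?"])
    await orchestrator.print_conversations_async()


if __name__ == "__main__":
    asyncio.run(main())
```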
2 changes: 2 additions & 0 deletions pyrit/prompt_target/__init__.py
@@ -12,6 +12,7 @@
from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget
from pyrit.prompt_target.crucible_target import CrucibleTarget
from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget
from pyrit.prompt_target.groq_chat_target import GroqChatTarget
from pyrit.prompt_target.http_target.http_target import HTTPTarget
from pyrit.prompt_target.http_target.http_target_callback_functions import (
get_http_target_json_response_callback_function,
@@ -34,6 +35,7 @@
"CrucibleTarget",
"GandalfLevel",
"GandalfTarget",
"GroqChatTarget",
"get_http_target_json_response_callback_function",
"get_http_target_regex_matching_callback_function",
"HTTPTarget",
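The two added lines import `GroqChatTarget` from its submodule and list it in `__all__`, so the class is importable from the package root. A quick illustration of what the re-export enables (a sketch, not code from the PR):

```python
# Both import paths resolve to the same class once __init__.py re-exports it.
from pyrit.prompt_target import GroqChatTarget
from pyrit.prompt_target.groq_chat_target import GroqChatTarget as GroqChatTargetDirect

assert GroqChatTarget is GroqChatTargetDirect
```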
129 changes: 129 additions & 0 deletions pyrit/prompt_target/groq_chat_target.py
@@ -0,0 +1,129 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from pyrit.common import default_values
from pyrit.exceptions import EmptyResponseException, PyritException, pyrit_target_retry
from pyrit.models import ChatMessageListDictContent
from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget

logger = logging.getLogger(__name__)


class GroqChatTarget(OpenAIChatTarget):
"""
A chat target for interacting with Groq's OpenAI-compatible API.

This class extends `OpenAIChatTarget` and ensures compatibility with Groq's API,
which requires `msg.content` to be a string instead of a list of dictionaries.

Attributes:
API_KEY_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq API key.
MODEL_NAME_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq model name.
GROQ_API_BASE_URL (str): The fixed API base URL for Groq.
"""

API_KEY_ENVIRONMENT_VARIABLE = "GROQ_API_KEY"
MODEL_NAME_ENVIRONMENT_VARIABLE = "GROQ_MODEL_NAME"
GROQ_API_BASE_URL = "https://api.groq.com/openai/v1/"

def __init__(self, *, model_name: str = None, api_key: str = None, max_requests_per_minute: int = None, **kwargs):
"""
Initializes GroqChatTarget with the correct API settings.

Args:
model_name (str, optional): The model to use. Defaults to `GROQ_MODEL_NAME` env variable.
api_key (str, optional): The API key for authentication. Defaults to `GROQ_API_KEY` env variable.
max_requests_per_minute (int, optional): Rate limit for requests.
"""

kwargs.pop("endpoint", None)
kwargs.pop("deployment_name", None)

super().__init__(
deployment_name=model_name,
endpoint=self.GROQ_API_BASE_URL,
api_key=api_key,
is_azure_target=False,
max_requests_per_minute=max_requests_per_minute,
**kwargs,
)

def _initialize_non_azure_vars(self, deployment_name: str, endpoint: str, api_key: str):
"""
Initializes variables to communicate with the (non-Azure) OpenAI API, in this case Groq.

Args:
deployment_name (str): The model name.
endpoint (str): The API base URL.
api_key (str): The API key.

Raises:
ValueError: If _deployment_name or _api_key is missing.
"""
self._api_key = default_values.get_required_value(
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
)
if not self._api_key:
raise ValueError("API key for Groq is missing. Ensure GROQ_API_KEY is set in the environment.")

self._deployment_name = default_values.get_required_value(
env_var_name=self.MODEL_NAME_ENVIRONMENT_VARIABLE, passed_value=deployment_name
)
if not self._deployment_name:
raise ValueError("Model name for Groq is missing. Ensure GROQ_MODEL_NAME is set in the environment.")

# Ignoring mypy type error. The OpenAI client and Azure OpenAI client have the same private base class
self._async_client = AsyncOpenAI( # type: ignore
api_key=self._api_key, default_headers=self._extra_headers, base_url=endpoint
)

@pyrit_target_retry
async def _complete_chat_async(self, messages: list[ChatMessageListDictContent], is_json_response: bool) -> str:
"""
Completes asynchronous chat request.

Sends a chat message to the OpenAI chat model and retrieves the generated response.
This method modifies the request structure to ensure compatibility with Groq,
which requires `msg.content` as a string instead of a list of dictionaries.
msg.content -> msg.content[0].get("text")

Args:
messages (list[ChatMessageListDictContent]): The chat message objects containing the role and content.
is_json_response (bool): Boolean indicating if the response should be in JSON format.

Returns:
str: The generated response message.
"""
response: ChatCompletion = await self._async_client.chat.completions.create(
model=self._deployment_name,
max_completion_tokens=self._max_completion_tokens,
max_tokens=self._max_tokens,
temperature=self._temperature,
top_p=self._top_p,
frequency_penalty=self._frequency_penalty,
presence_penalty=self._presence_penalty,
n=1,
stream=False,
seed=self._seed,
messages=[{"role": msg.role, "content": msg.content[0].get("text")} for msg in messages], # type: ignore
response_format={"type": "json_object"} if is_json_response else None,
)
finish_reason = response.choices[0].finish_reason
extracted_response: str = ""
# finish_reason="stop" means API returned complete message and
# "length" means API returned incomplete message due to max_tokens limit.
if finish_reason in ["stop", "length"]:
extracted_response = self._parse_chat_completion(response)
# Handle empty response
if not extracted_response:
logger.log(logging.ERROR, "The chat returned an empty response.")
raise EmptyResponseException(message="The chat returned an empty response.")
else:
raise PyritException(message=f"Unknown finish_reason {finish_reason}")

return extracted_response
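The docstring of `_complete_chat_async` states the key compatibility change: Groq expects `content` to be a plain string, while PyRIT's `ChatMessageListDictContent` stores it as a list of dicts, so the target sends `msg.content[0].get("text")`. A small illustration of that reshaping on plain dictionaries (the sample data is illustrative, not taken from the PR):

```python
# OpenAI-style message: content is a list of typed parts.
openai_style_message = {
    "role": "user",
    "content": [{"type": "text", "text": "Why is the sky blue?"}],
}

# Groq-compatible message: content flattened to the first text part,
# mirroring msg.content[0].get("text") in _complete_chat_async above.
groq_style_message = {
    "role": openai_style_message["role"],
    "content": openai_style_message["content"][0].get("text"),
}

print(groq_style_message)
# {'role': 'user', 'content': 'Why is the sky blue?'}
```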