-
Notifications
You must be signed in to change notification settings - Fork 438
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
Showing
6 changed files
with
803 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0", | ||
"metadata": { | ||
"lines_to_next_cell": 0 | ||
}, | ||
"source": [ | ||
"# GroqChatTarget\n", | ||
"\n", | ||
"This example demonstrates how to use the `GroqChatTarget` class in PyRIT to send a prompt\n", | ||
"to a Groq model and retrieve a response.\n", | ||
"\n", | ||
"## Setup\n", | ||
"Before running this example, you need to set the following environment variables:\n", | ||
"\n", | ||
"```\n", | ||
"export GROQ_API_KEY=\"your_api_key_here\"\n", | ||
"export GROQ_MODEL_NAME=\"llama3-8b-8192\"\n", | ||
"```\n", | ||
"\n", | ||
"Alternatively, you can pass these values as arguments when initializing `GroqChatTarget`:\n", | ||
"\n", | ||
"```python\n", | ||
"groq_target = GroqChatTarget(model_name=\"llama3-8b-8192\", api_key=\"your_api_key_here\")\n", | ||
"```\n", | ||
"\n", | ||
"You can also limit the request rate using `max_requests_per_minute`.\n", | ||
"\n", | ||
"## Example\n", | ||
"The following code initializes `GroqChatTarget`, sends a prompt using `PromptSendingOrchestrator`,\n", | ||
"and retrieves a response." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[22m\u001b[39mConversation ID: 7ae4ae98-a23b-4330-9c3e-5fd9e8c37854\n", | ||
"\u001b[1m\u001b[34muser: Why is the sky blue ?\n", | ||
"\u001b[22m\u001b[33massistant: The sky appears blue because of a phenomenon called Rayleigh scattering, which is the scattering of light by small particles or molecules in the atmosphere.\n", | ||
"\n", | ||
"When sunlight enters Earth's atmosphere, it encounters tiny molecules of gases such as nitrogen (N2) and oxygen (O2). These molecules scatter the light in all directions, but they scatter shorter (blue) wavelengths more than longer (red) wavelengths. This is known as Rayleigh scattering.\n", | ||
"\n", | ||
"As a result of this scattering, the blue light is dispersed throughout the atmosphere, reaching our eyes from all directions. This is why the sky appears blue during the daytime, as the blue light is being scattered in all directions and reaching our eyes from all parts of the sky.\n", | ||
"\n", | ||
"In addition to Rayleigh scattering, there are other factors that can affect the color of the sky, such as:\n", | ||
"\n", | ||
"* Mie scattering: This is the scattering of light by larger particles, such as dust, pollen, and water droplets. Mie scattering can give the sky a more orange or pinkish hue during sunrise and sunset.\n", | ||
"* Scattering by cloud droplets: Clouds can scatter light in a way that gives the sky a more white or gray appearance.\n", | ||
"* Atmospheric conditions: Factors such as pollution, dust, and water vapor can also affect the color of the sky, making it appear more hazy or brownish.\n", | ||
"\n", | ||
"Overall, the combination of Rayleigh scattering and other atmospheric effects is what gives the sky its blue color during the daytime.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"\n", | ||
"from pyrit.common import IN_MEMORY, initialize_pyrit\n", | ||
"from pyrit.orchestrator import PromptSendingOrchestrator\n", | ||
"from pyrit.prompt_target import GroqChatTarget\n", | ||
"\n", | ||
"initialize_pyrit(memory_db_type=IN_MEMORY)\n", | ||
"\n", | ||
"groq_target = GroqChatTarget()\n", | ||
"\n", | ||
"prompt = \"Why is the sky blue ?\"\n", | ||
"\n", | ||
"orchestrator = PromptSendingOrchestrator(objective_target=groq_target)\n", | ||
"\n", | ||
"response = await orchestrator.send_prompts_async(prompt_list=[prompt]) # type: ignore\n", | ||
"await orchestrator.print_conversations_async() # type: ignore" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"jupytext": { | ||
"cell_metadata_filter": "-all" | ||
}, | ||
"kernelspec": { | ||
"display_name": "pyrt_env", | ||
"language": "python", | ||
"name": "pyrt_env" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# --- | ||
# jupyter: | ||
# jupytext: | ||
# cell_metadata_filter: -all | ||
# text_representation: | ||
# extension: .py | ||
# format_name: percent | ||
# format_version: '1.3' | ||
# jupytext_version: 1.16.7 | ||
# kernelspec: | ||
# display_name: pyrt_env | ||
# language: python | ||
# name: pyrt_env | ||
# --- | ||
|
||
# %% [markdown] | ||
# # GroqChatTarget | ||
# | ||
# This example demonstrates how to use the `GroqChatTarget` class in PyRIT to send a prompt | ||
# to a Groq model and retrieve a response. | ||
# | ||
# ## Setup | ||
# Before running this example, you need to set the following environment variables: | ||
# | ||
# ``` | ||
# export GROQ_API_KEY="your_api_key_here" | ||
# export GROQ_MODEL_NAME="llama3-8b-8192" | ||
# ``` | ||
# | ||
# Alternatively, you can pass these values as arguments when initializing `GroqChatTarget`: | ||
# | ||
# ```python | ||
# groq_target = GroqChatTarget(model_name="llama3-8b-8192", api_key="your_api_key_here") | ||
# ``` | ||
# | ||
# You can also limit the request rate using `max_requests_per_minute`. | ||
# | ||
# ## Example | ||
# The following code initializes `GroqChatTarget`, sends a prompt using `PromptSendingOrchestrator`, | ||
# and retrieves a response. | ||
# %% | ||
|
||
from pyrit.common import IN_MEMORY, initialize_pyrit | ||
from pyrit.orchestrator import PromptSendingOrchestrator | ||
from pyrit.prompt_target import GroqChatTarget | ||
|
||
initialize_pyrit(memory_db_type=IN_MEMORY) | ||
|
||
groq_target = GroqChatTarget() | ||
|
||
prompt = "Why is the sky blue ?" | ||
|
||
orchestrator = PromptSendingOrchestrator(objective_target=groq_target) | ||
|
||
response = await orchestrator.send_prompts_async(prompt_list=[prompt]) # type: ignore | ||
await orchestrator.print_conversations_async() # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import logging | ||
|
||
from openai import AsyncOpenAI | ||
from openai.types.chat import ChatCompletion | ||
|
||
from pyrit.common import default_values | ||
from pyrit.exceptions import EmptyResponseException, PyritException, pyrit_target_retry | ||
from pyrit.models import ChatMessageListDictContent | ||
from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class GroqChatTarget(OpenAIChatTarget): | ||
""" | ||
A chat target for interacting with Groq's OpenAI-compatible API. | ||
This class extends `OpenAIChatTarget` and ensures compatibility with Groq's API, | ||
which requires `msg.content` to be a string instead of a list of dictionaries. | ||
Attributes: | ||
API_KEY_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq API key. | ||
MODEL_NAME_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq model name. | ||
GROQ_API_BASE_URL (str): The fixed API base URL for Groq. | ||
""" | ||
|
||
API_KEY_ENVIRONMENT_VARIABLE = "GROQ_API_KEY" | ||
MODEL_NAME_ENVIRONMENT_VARIABLE = "GROQ_MODEL_NAME" | ||
GROQ_API_BASE_URL = "https://api.groq.com/openai/v1/" | ||
|
||
def __init__(self, *, model_name: str = None, api_key: str = None, max_requests_per_minute: int = None, **kwargs): | ||
""" | ||
Initializes GroqChatTarget with the correct API settings. | ||
Args: | ||
model_name (str, optional): The model to use. Defaults to `GROQ_MODEL_NAME` env variable. | ||
api_key (str, optional): The API key for authentication. Defaults to `GROQ_API_KEY` env variable. | ||
max_requests_per_minute (int, optional): Rate limit for requests. | ||
""" | ||
|
||
kwargs.pop("endpoint", None) | ||
kwargs.pop("deployment_name", None) | ||
|
||
super().__init__( | ||
deployment_name=model_name, | ||
endpoint=self.GROQ_API_BASE_URL, | ||
api_key=api_key, | ||
is_azure_target=False, | ||
max_requests_per_minute=max_requests_per_minute, | ||
**kwargs, | ||
) | ||
|
||
def _initialize_non_azure_vars(self, deployment_name: str, endpoint: str, api_key: str): | ||
""" | ||
Initializes variables to communicate with the (non-Azure) OpenAI API, in this case Groq. | ||
Args: | ||
deployment_name (str): The model name. | ||
endpoint (str): The API base URL. | ||
api_key (str): The API key. | ||
Raises: | ||
ValueError: If _deployment_name or _api_key is missing. | ||
""" | ||
self._api_key = default_values.get_required_value( | ||
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key | ||
) | ||
if not self._api_key: | ||
raise ValueError("API key for Groq is missing. Ensure GROQ_API_KEY is set in the environment.") | ||
|
||
self._deployment_name = default_values.get_required_value( | ||
env_var_name=self.MODEL_NAME_ENVIRONMENT_VARIABLE, passed_value=deployment_name | ||
) | ||
if not self._deployment_name: | ||
raise ValueError("Model name for Groq is missing. Ensure GROQ_MODEL_NAME is set in the environment.") | ||
|
||
# Ignoring mypy type error. The OpenAI client and Azure OpenAI client have the same private base class | ||
self._async_client = AsyncOpenAI( # type: ignore | ||
api_key=self._api_key, default_headers=self._extra_headers, base_url=endpoint | ||
) | ||
|
||
@pyrit_target_retry | ||
async def _complete_chat_async(self, messages: list[ChatMessageListDictContent], is_json_response: bool) -> str: | ||
""" | ||
Completes asynchronous chat request. | ||
Sends a chat message to the OpenAI chat model and retrieves the generated response. | ||
This method modifies the request structure to ensure compatibility with Groq, | ||
which requires `msg.content` as a string instead of a list of dictionaries. | ||
msg.content -> msg.content[0].get("text") | ||
Args: | ||
messages (list[ChatMessageListDictContent]): The chat message objects containing the role and content. | ||
is_json_response (bool): Boolean indicating if the response should be in JSON format. | ||
Returns: | ||
str: The generated response message. | ||
""" | ||
response: ChatCompletion = await self._async_client.chat.completions.create( | ||
model=self._deployment_name, | ||
max_completion_tokens=self._max_completion_tokens, | ||
max_tokens=self._max_tokens, | ||
temperature=self._temperature, | ||
top_p=self._top_p, | ||
frequency_penalty=self._frequency_penalty, | ||
presence_penalty=self._presence_penalty, | ||
n=1, | ||
stream=False, | ||
seed=self._seed, | ||
messages=[{"role": msg.role, "content": msg.content[0].get("text")} for msg in messages], # type: ignore | ||
response_format={"type": "json_object"} if is_json_response else None, | ||
) | ||
finish_reason = response.choices[0].finish_reason | ||
extracted_response: str = "" | ||
# finish_reason="stop" means API returned complete message and | ||
# "length" means API returned incomplete message due to max_tokens limit. | ||
if finish_reason in ["stop", "length"]: | ||
extracted_response = self._parse_chat_completion(response) | ||
# Handle empty response | ||
if not extracted_response: | ||
logger.log(logging.ERROR, "The chat returned an empty response.") | ||
raise EmptyResponseException(message="The chat returned an empty response.") | ||
else: | ||
raise PyritException(message=f"Unknown finish_reason {finish_reason}") | ||
|
||
return extracted_response |
Oops, something went wrong.