diff --git a/pyproject.toml b/pyproject.toml index 311a9bbf8..0589a4b36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,6 +114,8 @@ playwright = [ "ollama>=0.4.4" ] +aws = ["boto3>=1.36.6"] + all = [ "accelerate==0.34.2", "azureml-mlflow==1.57.0", @@ -139,6 +141,7 @@ all = [ "flask>=3.1.0", "ollama>=0.4.4", "types-PyYAML>=6.0.12.9", + "boto3>=1.36.6" ] [project.scripts] diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 758716275..27372bd7d 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -7,7 +7,7 @@ from pyrit.prompt_target.openai.openai_target import OpenAITarget from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget - +from pyrit.prompt_target.aws_bedrock_claude_chat_target import AWSBedrockClaudeChatTarget from pyrit.prompt_target.azure_blob_storage_target import AzureBlobStorageTarget from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget from pyrit.prompt_target.crucible_target import CrucibleTarget @@ -30,6 +30,7 @@ from pyrit.prompt_target.text_target import TextTarget __all__ = [ + "AWSBedrockClaudeChatTarget", "AzureBlobStorageTarget", "AzureMLChatTarget", "CrucibleTarget", diff --git a/pyrit/prompt_target/aws_bedrock_claude_chat_target.py b/pyrit/prompt_target/aws_bedrock_claude_chat_target.py new file mode 100644 index 000000000..0e5a268f2 --- /dev/null +++ b/pyrit/prompt_target/aws_bedrock_claude_chat_target.py @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+
+import asyncio
+import base64
+import json
+import logging
+from typing import MutableSequence, Optional
+
+from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer
+from pyrit.models import (
+    ChatMessageListDictContent,
+    PromptRequestResponse,
+    construct_response_from_request,
+)
+from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute
+
+logger = logging.getLogger(__name__)
+
+
+class AWSBedrockClaudeChatTarget(PromptChatTarget):
+    """
+    This class initializes an AWS Bedrock target for any of the Anthropic Claude models.
+    Local AWS credentials (typically stored in ~/.aws) are used for authentication.
+    See the following for more information:
+    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+
+    Parameters:
+        model_id (str): The model ID for the target Claude model
+        max_tokens (int): maximum number of tokens to generate
+        temperature (float, optional): The amount of randomness injected into the response.
+        top_p (float, optional): Use nucleus sampling
+        top_k (int, optional): Only sample from the top K options for each subsequent token
+        enable_ssl_verification (bool, optional): whether or not to perform SSL certificate verification
+    """
+
+    def __init__(
+        self,
+        *,
+        model_id: str,
+        max_tokens: int,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        enable_ssl_verification: bool = True,
+        chat_message_normalizer: ChatMessageNormalizer = ChatMessageNop(),
+        max_requests_per_minute: Optional[int] = None,
+    ) -> None:
+        super().__init__(max_requests_per_minute=max_requests_per_minute)
+
+        self._model_id = model_id
+        self._max_tokens = max_tokens
+        self._temperature = temperature
+        self._top_p = top_p
+        self._top_k = top_k
+        self._enable_ssl_verification = enable_ssl_verification
+        self.chat_message_normalizer = chat_message_normalizer
+
+        # Set lazily when a system-role piece is seen in _build_chat_messages;
+        # Bedrock takes the system prompt as a separate request parameter.
+        self._system_prompt = ""
+
+        self._valid_image_types = ["jpeg", "png", "webp", "gif"]
+
+        # Fail fast at construction time if the optional aws dependency is missing.
+        try:
+            import boto3  # noqa: F401
+            from botocore.exceptions import ClientError  # noqa: F401
+        except ModuleNotFoundError as e:
+            logger.error("Could not import boto3. You may need to install it via 'pip install pyrit[aws]'")
+            raise e
+
+    @limit_requests_per_minute
+    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
+
+        self._validate_request(prompt_request=prompt_request)
+        request_piece = prompt_request.request_pieces[0]
+
+        prompt_req_res_entries = self._memory.get_conversation(conversation_id=request_piece.conversation_id)
+        prompt_req_res_entries.append(prompt_request)
+
+        logger.info(f"Sending the following prompt to the prompt target: {prompt_request}")
+
+        messages = await self._build_chat_messages(prompt_req_res_entries)
+
+        response = await self._complete_chat_async(messages=messages)
+
+        response_entry = construct_response_from_request(request=request_piece, response_text_pieces=[response])
+
+        return response_entry
+
+    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
+        converted_prompt_data_types = [
+            request_piece.converted_value_data_type for request_piece in prompt_request.request_pieces
+        ]
+
+        for prompt_data_type in converted_prompt_data_types:
+            if prompt_data_type not in ["text", "image_path"]:
+                raise ValueError("This target only supports text and image_path.")
+
+    async def _complete_chat_async(self, messages: list[ChatMessageListDictContent]) -> str:
+        # boto3 is imported lazily (not at module level) so PyRIT imports cleanly
+        # without the aws extra installed; __init__ already verified it is importable.
+        import boto3
+
+        brt = boto3.client(
+            service_name="bedrock-runtime", region_name="us-east-1", verify=self._enable_ssl_verification
+        )
+
+        native_request = self._construct_request_body(messages)
+
+        request = json.dumps(native_request)
+
+        try:
+            # invoke_model is a blocking boto3 call; run it in a worker thread.
+            response = await asyncio.to_thread(brt.invoke_model, modelId=self._model_id, body=request)
+        except Exception as e:
+            raise ValueError(f"ERROR: Can't invoke '{self._model_id}'. Reason: {e}") from e
+
+        model_response = json.loads(response["body"].read())
+
+        answer = model_response["content"][0]["text"]
+
+        logger.info(f'Received the following response from the prompt target "{answer}"')
+        return answer
+
+    def _convert_local_image_to_base64(self, image_path: str) -> str:
+        with open(image_path, "rb") as image_file:
+            encoded_string = base64.b64encode(image_file.read())
+        return encoded_string.decode()
+
+    async def _build_chat_messages(
+        self, prompt_req_res_entries: MutableSequence[PromptRequestResponse]
+    ) -> list[ChatMessageListDictContent]:
+        chat_messages: list[ChatMessageListDictContent] = []
+        for prompt_req_resp_entry in prompt_req_res_entries:
+            prompt_request_pieces = prompt_req_resp_entry.request_pieces
+
+            content = []
+            role = None
+            for prompt_request_piece in prompt_request_pieces:
+                role = prompt_request_piece.role
+                if role == "system":
+                    # Bedrock doesn't allow a message with role==system,
+                    # but it does let you specify system role in a param
+                    self._system_prompt = prompt_request_piece.converted_value
+                elif prompt_request_piece.converted_value_data_type == "text":
+                    entry = {"type": "text", "text": prompt_request_piece.converted_value}
+                    content.append(entry)
+                elif prompt_request_piece.converted_value_data_type == "image_path":
+                    image_type = prompt_request_piece.converted_value.split(".")[-1]
+                    if image_type not in self._valid_image_types:
+                        raise ValueError(
+                            f"""Image file {prompt_request_piece.converted_value} must
+                            have valid extension of .jpeg, .png, .webp, or .gif"""
+                        )
+
+                    data_base64_encoded = self._convert_local_image_to_base64(prompt_request_piece.converted_value)
+                    media_type = "image/" + image_type
+                    entry = {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": media_type,
+                            "data": data_base64_encoded,
+                        },  # type: ignore
+                    }
+                    content.append(entry)
+                else:
+                    raise ValueError(
+                        f"Multimodal data type {prompt_request_piece.converted_value_data_type} is not yet supported."
+                    )
+
+            if not role:
+                raise ValueError("No role could be determined from the prompt request pieces.")
+
+            chat_message = ChatMessageListDictContent(role=role, content=content)
+            chat_messages.append(chat_message)
+        return chat_messages
+
+    def _construct_request_body(self, messages_list: list[ChatMessageListDictContent]) -> dict:
+        content = []
+
+        # System messages are excluded here; the system prompt is sent via the
+        # dedicated "system" parameter below, as Bedrock requires.
+        for message in messages_list:
+            if message.role != "system":
+                entry = {"role": message.role, "content": message.content}
+                content.append(entry)
+
+        data = {
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": self._max_tokens,
+            "system": self._system_prompt,
+            "messages": content,
+        }
+
+        # Explicit None checks so legitimate zero values
+        # (e.g. temperature=0.0, top_k=0) are not silently dropped.
+        if self._temperature is not None:
+            data["temperature"] = self._temperature
+        if self._top_p is not None:
+            data["top_p"] = self._top_p
+        if self._top_k is not None:
+            data["top_k"] = self._top_k
+
+        return data
+
+    def is_json_response_supported(self) -> bool:
+        """Indicates that this target does not support JSON response format."""
+        return False
diff --git a/tests/unit/test_aws_bedrock_claude_chat_target.py b/tests/unit/test_aws_bedrock_claude_chat_target.py
new file mode 100644
index 000000000..0d9a86630
--- /dev/null
+++ b/tests/unit/test_aws_bedrock_claude_chat_target.py
@@ -0,0 +1,95 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from pyrit.models import (
+    ChatMessageListDictContent,
+    PromptRequestPiece,
+    PromptRequestResponse,
+)
+from pyrit.prompt_target.aws_bedrock_claude_chat_target import (
+    AWSBedrockClaudeChatTarget,
+)
+
+
+# Helper used by the skipif markers: these tests only run when the optional
+# aws extra (boto3) is installed.
+def is_boto3_installed():
+    try:
+        import boto3  # noqa: F401
+
+        return True
+    except ModuleNotFoundError:
+        return False
+
+
+@pytest.fixture
+def aws_target() -> AWSBedrockClaudeChatTarget:
+    # A fully-parameterized target; no AWS call is made at construction time.
+    return AWSBedrockClaudeChatTarget(
+        model_id="anthropic.claude-v2",
+        max_tokens=100,
+        temperature=0.7,
+        top_p=0.9,
+        top_k=50,
+        enable_ssl_verification=True,
+    )
+
+
+@pytest.fixture
+def mock_prompt_request():
+    # Minimal single-piece text request used by the happy-path tests.
+    request_piece = PromptRequestPiece(
+        role="user", original_value="Hello, Claude!", converted_value="Hello, how are you?"
+    )
+    return PromptRequestResponse(request_pieces=[request_piece])
+
+
+@pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed")
+@pytest.mark.asyncio
+async def test_send_prompt_async(aws_target, mock_prompt_request):
+    # Patch boto3.client so no real Bedrock endpoint is contacted; the mocked
+    # invoke_model returns a Bedrock-shaped body (streaming object with .read()).
+    with patch("boto3.client", new_callable=MagicMock) as mock_boto:
+        mock_client = mock_boto.return_value
+        mock_client.invoke_model.return_value = {
+            "body": MagicMock(read=MagicMock(return_value=json.dumps({"content": [{"text": "I'm good, thanks!"}]})))
+        }
+
+        response = await aws_target.send_prompt_async(prompt_request=mock_prompt_request)
+
+        assert response.request_pieces[0].converted_value == "I'm good, thanks!"
+
+
+@pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed")
+@pytest.mark.asyncio
+async def test_validate_request_valid(aws_target, mock_prompt_request):
+    # A text-only request must pass validation without raising.
+    aws_target._validate_request(prompt_request=mock_prompt_request)
+
+
+@pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed")
+@pytest.mark.asyncio
+async def test_validate_request_invalid_data_type(aws_target):
+    # Only "text" and "image_path" data types are accepted by this target.
+    request_pieces = [
+        PromptRequestPiece(
+            role="user", original_value="test", converted_value="ImageData", converted_value_data_type="video"
+        )
+    ]
+    invalid_request = PromptRequestResponse(request_pieces=request_pieces)
+
+    with pytest.raises(ValueError, match="This target only supports text and image_path."):
+        aws_target._validate_request(prompt_request=invalid_request)
+
+
+@pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed")
+@pytest.mark.asyncio
+async def test_complete_chat_async(aws_target):
+    # Exercises the low-level chat completion path directly with a mocked client.
+    with patch("boto3.client", new_callable=MagicMock) as mock_boto:
+        mock_client = mock_boto.return_value
+        mock_client.invoke_model.return_value = {
+            "body": MagicMock(read=MagicMock(return_value=json.dumps({"content": [{"text": "Test Response"}]})))
+        }
+
+        response = await aws_target._complete_chat_async(
+            messages=[ChatMessageListDictContent(role="user", content=[{"type": "text", "text": "Test input"}])]
+        )
+
+        assert response == "Test Response"