Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT new target class for AWS Bedrock Anthropic Claude models #699

Open
wants to merge 33 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ed1e3ff
Adding AWS Bedrock Anthropic Claude target class
kmarsh77 Feb 7, 2025
63a9b2e
Adding unit tests for AWSBedrockClaudeTarget class
kmarsh77 Feb 7, 2025
5de223d
Add optional aws dependency (boto3)
kmarsh77 Feb 10, 2025
ac87b28
Update aws_bedrock_claude_target.py
kmarsh77 Feb 10, 2025
45785d6
Adding bedrock claude target class
kmarsh77 Feb 10, 2025
f7a8767
Update __init__.py for new target classes
kmarsh77 Feb 10, 2025
57252d0
Unit test for AWSBedrockClaudeChatTarget
kmarsh77 Feb 10, 2025
f254145
Delete pyrit/prompt_target/aws_bedrock_claude_target.py
kmarsh77 Feb 12, 2025
f0bc2bc
Update __init__.py
kmarsh77 Feb 12, 2025
2ebb519
Delete tests/unit/test_aws_bedrock_claude_target.py
kmarsh77 Feb 12, 2025
258f287
Update aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
408c308
Update test_aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
5865ac6
Update pyproject.toml
kmarsh77 Feb 12, 2025
01addd1
Update aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
627396a
Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
59dcb7e
Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
b5c4924
Update aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
5d8d7e0
Update aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
185bcff
Update aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
2630b43
Update test_aws_bedrock_claude_chat_target.py
kmarsh77 Feb 12, 2025
7bcb075
Merge branch 'main' into main
kmarsh77 Feb 19, 2025
6d0485d
Merge branch 'Azure:main' into main
kmarsh77 Feb 19, 2025
0e0e300
Updates to address complaints from pre-commit hooks
kmarsh77 Feb 20, 2025
187cb16
Merge branch 'main' into main
romanlutz Feb 25, 2025
3fd876b
Merge branch 'main' into main
romanlutz Feb 26, 2025
ef3ef17
Merge branch 'main' into main
romanlutz Feb 26, 2025
6d531d5
Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py
romanlutz Feb 26, 2025
e7c3c54
Adding exceptions for when boto3 isn't installed
kmarsh77 Feb 27, 2025
8ddb596
Adding exceptions for when boto3 isn't installed
kmarsh77 Feb 27, 2025
b80cbda
Merge branch 'main' of https://github.com/kmarsh77/PyRIT
kmarsh77 Feb 27, 2025
b772a9c
Adding noqa statements to pass pre-commit checks
kmarsh77 Feb 28, 2025
e4b10d3
Merge branch 'Azure:main' into main
kmarsh77 Feb 28, 2025
d88919a
Update tests/unit/test_aws_bedrock_claude_chat_target.py
romanlutz Feb 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ playwright = [
"ollama>=0.4.4"
]

aws = ["boto3>=1.36.6"]

all = [
"accelerate==0.34.2",
"azureml-mlflow==1.57.0",
Expand All @@ -139,6 +141,7 @@ all = [
"flask>=3.1.0",
"ollama>=0.4.4",
"types-PyYAML>=6.0.12.9",
"boto3>=1.36.6"
]

[project.scripts]
Expand Down
3 changes: 2 additions & 1 deletion pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pyrit.prompt_target.openai.openai_target import OpenAITarget
from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget


from pyrit.prompt_target.aws_bedrock_claude_chat_target import AWSBedrockClaudeChatTarget
from pyrit.prompt_target.azure_blob_storage_target import AzureBlobStorageTarget
from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget
from pyrit.prompt_target.crucible_target import CrucibleTarget
Expand All @@ -30,6 +30,7 @@
from pyrit.prompt_target.text_target import TextTarget

__all__ = [
"AWSBedrockClaudeChatTarget",
"AzureBlobStorageTarget",
"AzureMLChatTarget",
"CrucibleTarget",
Expand Down
195 changes: 195 additions & 0 deletions pyrit/prompt_target/aws_bedrock_claude_chat_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import base64
import json
import logging
from typing import MutableSequence, Optional

import boto3
from botocore.exceptions import ClientError

from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer
from pyrit.models import (
ChatMessageListDictContent,
PromptRequestResponse,
construct_response_from_request,
)
from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute

logger = logging.getLogger(__name__)


class AWSBedrockClaudeChatTarget(PromptChatTarget):
    """
    This class initializes an AWS Bedrock target for any of the Anthropic Claude models.
    Local AWS credentials (typically stored in ~/.aws) are used for authentication.
    See the following for more information:
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html

    Parameters:
        model_id (str): The model ID for target claude model
        max_tokens (int): maximum number of tokens to generate
        temperature (float, optional): The amount of randomness injected into the response.
        top_p (float, optional): Use nucleus sampling
        top_k (int, optional): Only sample from the top K options for each subsequent token
        enable_ssl_verification (bool, optional): whether or not to perform SSL certificate verification
        region_name (str, optional): AWS region hosting the Bedrock runtime.
            Defaults to "us-east-1" (the previously hard-coded value).
    """

    def __init__(
        self,
        *,
        model_id: str,
        max_tokens: int,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        enable_ssl_verification: bool = True,
        chat_message_normalizer: Optional[ChatMessageNormalizer] = None,
        max_requests_per_minute: Optional[int] = None,
        region_name: str = "us-east-1",
    ) -> None:
        super().__init__(max_requests_per_minute=max_requests_per_minute)

        self._model_id = model_id
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._top_p = top_p
        self._top_k = top_k
        self._enable_ssl_verification = enable_ssl_verification
        self._region_name = region_name
        # Avoid a mutable default argument (a single ChatMessageNop shared across
        # every instance); build a fresh normalizer per target instead.
        self.chat_message_normalizer = (
            chat_message_normalizer if chat_message_normalizer is not None else ChatMessageNop()
        )

        # Bedrock takes the system prompt as a separate request field rather than a
        # role=="system" message; it is captured in _build_chat_messages.
        self._system_prompt = ""

        self._valid_image_types = ["jpeg", "png", "webp", "gif"]

    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        """
        Send the prompt (plus prior turns of its conversation) to the Bedrock model.

        Args:
            prompt_request (PromptRequestResponse): The request to send; only "text" and
                "image_path" pieces are supported.

        Returns:
            PromptRequestResponse: Response entry constructed from the model's text answer.

        Raises:
            ValueError: If the request contains unsupported data types or the model
                invocation fails.
        """
        self._validate_request(prompt_request=prompt_request)
        request_piece = prompt_request.request_pieces[0]

        prompt_req_res_entries = self._memory.get_conversation(conversation_id=request_piece.conversation_id)
        prompt_req_res_entries.append(prompt_request)

        logger.info(f"Sending the following prompt to the prompt target: {prompt_request}")

        messages = await self._build_chat_messages(prompt_req_res_entries)

        response = await self._complete_chat_async(messages=messages)

        response_entry = construct_response_from_request(request=request_piece, response_text_pieces=[response])

        return response_entry

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        """Raise ValueError if any request piece is not "text" or "image_path"."""
        converted_prompt_data_types = [
            request_piece.converted_value_data_type for request_piece in prompt_request.request_pieces
        ]

        for prompt_data_type in converted_prompt_data_types:
            if prompt_data_type not in ["text", "image_path"]:
                raise ValueError("This target only supports text and image_path.")

    async def _complete_chat_async(self, messages: list[ChatMessageListDictContent]) -> str:
        """Invoke the Bedrock model and return the first text block of its answer."""
        brt = boto3.client(
            service_name="bedrock-runtime",
            region_name=self._region_name,
            verify=self._enable_ssl_verification,
        )

        native_request = self._construct_request_body(messages)

        request = json.dumps(native_request)

        try:
            # boto3 is synchronous; run it on a worker thread so the event loop isn't blocked.
            response = await asyncio.to_thread(brt.invoke_model, modelId=self._model_id, body=request)
        except Exception as e:
            # Covers botocore.exceptions.ClientError (service errors) and anything else;
            # chain the cause so the original traceback is preserved.
            raise ValueError(f"ERROR: Can't invoke '{self._model_id}'. Reason: {e}") from e

        model_response = json.loads(response["body"].read())

        answer = model_response["content"][0]["text"]

        logger.info(f'Received the following response from the prompt target "{answer}"')
        return answer

    def _convert_local_image_to_base64(self, image_path: str) -> str:
        """Read a local image file and return its contents base64-encoded as a str."""
        with open(image_path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read())
        return encoded_string.decode()

    async def _build_chat_messages(
        self, prompt_req_res_entries: MutableSequence[PromptRequestResponse]
    ) -> list[ChatMessageListDictContent]:
        """
        Convert conversation entries into the Anthropic Messages content format.

        System-role pieces are stashed in self._system_prompt (Bedrock passes the
        system prompt as a separate request parameter, not as a message).
        """
        chat_messages: list[ChatMessageListDictContent] = []
        for prompt_req_resp_entry in prompt_req_res_entries:
            prompt_request_pieces = prompt_req_resp_entry.request_pieces

            content: list[dict] = []
            role = None
            for prompt_request_piece in prompt_request_pieces:
                role = prompt_request_piece.role
                if role == "system":
                    # Bedrock doesn't allow a message with role==system,
                    # but it does let you specify system role in a param
                    self._system_prompt = prompt_request_piece.converted_value
                elif prompt_request_piece.converted_value_data_type == "text":
                    text_entry = {"type": "text", "text": prompt_request_piece.converted_value}
                    content.append(text_entry)
                elif prompt_request_piece.converted_value_data_type == "image_path":
                    # Normalize case so ".PNG"/".Jpeg" etc. are accepted too.
                    image_type = prompt_request_piece.converted_value.split(".")[-1].lower()
                    if image_type not in self._valid_image_types:
                        raise ValueError(
                            f"""Image file {prompt_request_piece.converted_value} must
                            have valid extension of .jpeg, .png, .webp, or .gif"""
                        )

                    data_base64_encoded = self._convert_local_image_to_base64(prompt_request_piece.converted_value)
                    media_type = "image/" + image_type
                    image_entry = {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": data_base64_encoded,
                        },
                    }
                    content.append(image_entry)
                else:
                    raise ValueError(
                        f"Multimodal data type {prompt_request_piece.converted_value_data_type} is not yet supported."
                    )

            if not role:
                raise ValueError("No role could be determined from the prompt request pieces.")

            chat_message = ChatMessageListDictContent(role=role, content=content)
            chat_messages.append(chat_message)
        return chat_messages

    def _construct_request_body(self, messages_list: list[ChatMessageListDictContent]) -> dict:
        """Build the native Anthropic-on-Bedrock request body from the chat messages."""
        content = []

        for message in messages_list:
            if message.role != "system":
                entry = {"role": message.role, "content": message.content}
                content.append(entry)

        data = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self._max_tokens,
            "system": self._system_prompt,
            "messages": content,
        }

        # Compare against None, not truthiness: temperature=0.0 / top_p=0.0 are
        # legitimate values and must not be silently dropped.
        if self._temperature is not None:
            data["temperature"] = self._temperature
        if self._top_p is not None:
            data["top_p"] = self._top_p
        if self._top_k is not None:
            data["top_k"] = self._top_k

        return data

    def is_json_response_supported(self) -> bool:
        """Indicates that this target supports JSON response format."""
        return False
82 changes: 82 additions & 0 deletions tests/unit/test_aws_bedrock_claude_chat_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
from unittest.mock import MagicMock, patch

import pytest

from pyrit.models import (
ChatMessageListDictContent,
PromptRequestPiece,
PromptRequestResponse,
)
from pyrit.prompt_target.aws_bedrock_claude_chat_target import (
AWSBedrockClaudeChatTarget,
)


@pytest.fixture
def aws_target() -> AWSBedrockClaudeChatTarget:
    """Build a target configured with representative sampling parameters."""
    target_config = {
        "model_id": "anthropic.claude-v2",
        "max_tokens": 100,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "enable_ssl_verification": True,
    }
    return AWSBedrockClaudeChatTarget(**target_config)


@pytest.fixture
def mock_prompt_request():
    """Single-turn user request with a simple text payload."""
    piece = PromptRequestPiece(
        role="user",
        original_value="Hello, Claude!",
        converted_value="Hello, how are you?",
    )
    return PromptRequestResponse(request_pieces=[piece])


@pytest.mark.asyncio
async def test_send_prompt_async(aws_target, mock_prompt_request):
    """The target should surface the mocked Bedrock answer as the converted value."""
    with patch("boto3.client", new_callable=MagicMock) as mock_boto:
        bedrock_client = mock_boto.return_value
        reply_body = json.dumps({"content": [{"text": "I'm good, thanks!"}]})
        bedrock_client.invoke_model.return_value = {
            "body": MagicMock(read=MagicMock(return_value=reply_body))
        }

        response = await aws_target.send_prompt_async(prompt_request=mock_prompt_request)

        assert response.request_pieces[0].converted_value == "I'm good, thanks!"


@pytest.mark.asyncio
async def test_validate_request_valid(aws_target, mock_prompt_request):
    # A plain text request is supported; validation must not raise.
    aws_target._validate_request(prompt_request=mock_prompt_request)


@pytest.mark.asyncio
async def test_validate_request_invalid_data_type(aws_target):
    """An unsupported data type (video) must be rejected with a ValueError."""
    video_piece = PromptRequestPiece(
        role="user",
        original_value="test",
        converted_value="ImageData",
        converted_value_data_type="video",
    )
    invalid_request = PromptRequestResponse(request_pieces=[video_piece])

    with pytest.raises(ValueError, match="This target only supports text and image_path."):
        aws_target._validate_request(prompt_request=invalid_request)


@pytest.mark.asyncio
async def test_complete_chat_async(aws_target):
    """_complete_chat_async should return the text block from the mocked model body."""
    with patch("boto3.client", new_callable=MagicMock) as mock_boto:
        bedrock_client = mock_boto.return_value
        payload = json.dumps({"content": [{"text": "Test Response"}]})
        bedrock_client.invoke_model.return_value = {
            "body": MagicMock(read=MagicMock(return_value=payload))
        }

        message = ChatMessageListDictContent(
            role="user", content=[{"type": "text", "text": "Test input"}]
        )
        result = await aws_target._complete_chat_async(messages=[message])

        assert result == "Test Response"
Loading