Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Supporting the new tongue tied Gandalf levels #356

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
a241984
Adding support for first two levels of Gandalf Tongue Tied
Sep 3, 2024
e3deeb2
Merge branch 'main' into tongue-tied/two-levels
Sep 3, 2024
0e1f325
adding tests for tongue tied scorer and removing duplicate target
Sep 3, 2024
ad84fac
removing stray breakpoint
Sep 3, 2024
f236ba8
FEAT Add SQL Entra Auth for Azure SQL Server (#330)
elgertam Sep 4, 2024
3d61482
[MAINT] Fix typos in OllamaChatTarget (#357)
riedgar-ms Sep 4, 2024
b4121e7
Reuse original Gandalf target and use built-in scorer for Tongue Tied…
s-zanella Sep 5, 2024
8b575b2
Fix typo
s-zanella Sep 5, 2024
ed68c14
Use correct target for the attacker. Use method `_create_normalizer_r…
s-zanella Sep 5, 2024
89ac0af
Remove unused `threshold` parameter in `TrueFalseInverterScore`
s-zanella Sep 5, 2024
10b5e30
Tongue Tied Gandalf notebook (WIP): ad-hoc orchestrator, built-in sco…
s-zanella Sep 5, 2024
7518705
resolve conflict from origin
Sep 5, 2024
87c1eef
Merge branch 'main' into tongue-tied/two-levels
s-zanella Sep 5, 2024
7bbe9ed
Prompt that solves level 2
s-zanella Sep 5, 2024
9b01377
gandalf_tongue_tied_scorer -> GandalfTongueTiedScorer
s-zanella Sep 5, 2024
ba9a36a
Line endings
s-zanella Sep 5, 2024
87fcb31
Add pct file for notebook. Prompt that solves level 3
s-zanella Sep 6, 2024
54267c1
Merge branch 'main' into tongue-tied/two-levels
s-zanella Sep 6, 2024
1611538
Merge branch 'main' into tongue-tied/two-levels
s-zanella Sep 11, 2024
bc7fa49
Simplify answer parsing; fix mypy type errors; trim whitespace
s-zanella Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env_example
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,6 @@ AZURE_SQL_SERVER_CONNECTION_STRING="<Provide DB Azure SQL Server connection stri

# Crucible API Key. You can get yours at: https://crucible.dreadnode.io/login
CRUCIBLE_API_KEY = "<Provide Crucible API key here>"

# Azure SQL Server Connection String
AZURE_SQL_DB_CONNECTION_STRING = "<Provide Azure SQL DB connection string here in SQLAlchemy format>"
11 changes: 5 additions & 6 deletions doc/setup/use_sql_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
In order to connect PyRIT with Azure SQL Server, an Azure SQL Server instance with username & password
authentication enabled is required. If you are creating a new Azure SQL Server resource, be sure to note the password for your "Server Admin." Otherwise, if you have an existing Azure SQL Server resource, you can reset the password from the "Overview" page.

PyRIT does not yet support Microsoft Entra ID (formerly known as Azure Active Directory) when accessing Azure SQL Server. Therefore, ensure your server is configured to take non-Entra connections. To do that, navigate in the Azure Portal to Settings -&gt; Microsoft Entra ID. Under the heading "Microsoft Entra authentication only", uncheck the box reading "Support only Microsoft Entra authentication for this server." Then save the configuration.
SQL Server requires Microsoft Entra authentication (formerly known as Azure Active Directory). To ensure this works, be
sure to install the `az` utility and have it availble on PATH for Python to access.

Finally, firewall rules can prevent you or your team from accessing SQL Server. To ensure you and your team have access, collect any public IP addresses of anyone who may need access to Azure SQL Server while running PyRIT. Once these are collected, navigate in the Azure Portal to Security -&gt; Networking. Under the heading "Firewall rules," click "+ Add a firewall rule" for each IP address that must be granted access. If the rule has only one IP address, copy the vame value into "Start IPv4 Address" and "End IPv4 Address." Then save this configuration.
Firewall rules can prevent you or your team from accessing SQL Server. To ensure you and your team have access, collect any public IP addresses of anyone who may need access to Azure SQL Server while running PyRIT. Once these are collected, navigate in the Azure Portal to Security -&gt; Networking. Under the heading "Firewall rules," click "+ Add a firewall rule" for each IP address that must be granted access. If the rule has only one IP address, copy the vame value into "Start IPv4 Address" and "End IPv4 Address." Then save this configuration.

## Configure SQL Database

Expand All @@ -21,7 +22,7 @@ Connecting PyRIT to an Azure SQL Server database requires ODBC, PyODBC and Micro

Once ODBC and the SQL Server driver have been configured, you must use the `AzureSQLMemory` implementation of `MemoryInterface` from the `pyrit.memory.azure_sql_server` module to connect PyRIT to an Azure SQL Server database.

The constructor for `AzureSQLMemory` requires a URL connection string of the form: `mssql+pyodbc://<username>:<password>@<serverName>.database.windows.net/<databaseName>?driver=<driver string>`, where `<username>` and `<password>` are the SQL Server username and password configured above, `<serverName>` is the "Server name" as specified on the Azure SQL Server "Overview" page, `<databaseName>` is the name of the database instance created above, and `<driver string>` is the driver identifier (likely `ODBC+Driver+18+for+SQL+Server` if you installed the latest version of Microsoft's ODBC driver).
The constructor for `AzureSQLMemory` requires a URL connection string of the form: `mssql+pyodbc://@<serverName>.database.windows.net/<databaseName>?driver=<driver string>`, where `<serverName>` is the "Server name" as specified on the Azure SQL Server "Overview" page, `<databaseName>` is the name of the database instance created above, and `<driver string>` is the driver identifier (likely `ODBC+Driver+18+for+SQL+Server` if you installed the latest version of Microsoft's ODBC driver).

## Use PyRIT with Azure SQL Server

Expand All @@ -35,9 +36,7 @@ from pyrit.memory import AzureSQLServer

default_values.load_default_env()

conn_str = os.environ.get('AZURE_SQL_SERVER_CONNECTION_STRING')

azure_memory = AzureSQLServer(connection_string=conn_str)
azure_memory = AzureSQLServer()
```

Once you have created an instance of `AzureSQLServer`, the code will ensure that your Azure SQL Server database is properly configured with the appropriate tables.
50 changes: 47 additions & 3 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
# Licensed under the MIT license.

import logging
import struct

from contextlib import closing
from typing import Optional, Sequence

from sqlalchemy import create_engine, func, and_
from azure.identity import DefaultAzureCredential
from azure.core.credentials import AccessToken

from sqlalchemy import create_engine, func, and_, event
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session

from pyrit.common import default_values
from pyrit.common.singleton import Singleton
from pyrit.memory.memory_models import EmbeddingData, Base, PromptMemoryEntry, ScoreEntry
from pyrit.memory.memory_interface import MemoryInterface
Expand All @@ -30,16 +35,29 @@ class AzureSQLMemory(MemoryInterface, metaclass=Singleton):
and session management to perform database operations.
"""

def __init__(self, *, connection_string: str, verbose: bool = False):
SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h
TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database
AZURE_SQL_DB_CONNECTION_STRING = "AZURE_SQL_DB_CONNECTION_STRING"

def __init__(self, *, connection_string: Optional[str] = None, verbose: bool = False):
super(AzureSQLMemory, self).__init__()

self._connection_string = connection_string
self._connection_string = default_values.get_required_value(
env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string
)

self.engine = self._create_engine(has_echo=verbose)

self._auth_token = self._create_auth_token()
self._enable_azure_authorization()

self.SessionFactory = sessionmaker(bind=self.engine)
self._create_tables_if_not_exist()

def _create_auth_token(self) -> AccessToken:
azure_credentials = DefaultAzureCredential()
return azure_credentials.get_token(self.TOKEN_URL)

def _create_engine(self, *, has_echo: bool) -> Engine:
"""Creates the SQLAlchemy engine for Azure SQL Server.

Expand All @@ -59,6 +77,32 @@ def _create_engine(self, *, has_echo: bool) -> Engine:
logger.exception(f"Error creating the engine for the database: {e}")
raise

def _enable_azure_authorization(self) -> None:
"""
The following is necessary because of how SQLAlchemy and PyODBC handle connection creation. In PyODBC, the
token is passed outside the connection string in the `connect()` method. Since SQLAlchemy lazy-loads
its connections, we need to set this as a separate argument to the `connect()` method. In SQLALchemy
we do this by hooking into the `do_connect` event, which is fired when a connection is created.

For further details, see:
* <https://docs.sqlalchemy.org/en/20/dialects/mssql.html#connecting-to-databases-with-access-tokens>
* <https://learn.microsoft.com/en-us/azure/azure-sql/database/azure-sql-python-quickstart
#add-code-to-connect-to-azure-sql-database>
"""

@event.listens_for(self.engine, "do_connect")
def provide_token(_dialect, _conn_rec, cargs, cparams):
# remove the "Trusted_Connection" parameter that SQLAlchemy adds
cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")

# encode the token
azure_token = self._auth_token.token
azure_token_bytes = azure_token.encode("utf-16-le")
packed_azure_token = struct.pack(f"<I{len(azure_token_bytes)}s", len(azure_token_bytes), azure_token_bytes)

# add the encoded token
cparams["attrs_before"] = {self.SQL_COPT_SS_ACCESS_TOKEN: packed_azure_token}

def _create_tables_if_not_exist(self):
"""
Creates all tables defined in the Base metadata, if they don't already exist in the database.
Expand Down
26 changes: 5 additions & 21 deletions pyrit/prompt_target/gandalf_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class GandalfLevel(enum.Enum):
LEVEL_8 = "gandalf-the-white"
LEVEL_9 = "adventure-1"
LEVEL_10 = "adventure-2"
TONGUE_TIED_LEVEL_1 = "adventure-8"
TONGUE_TIED_LEVEL_2 = "adventure-9"
TONGUE_TIED_LEVEL_3 = "adventure-10"
TONGUE_TIED_LEVEL_4 = "adventure-11"
TONGUE_TIED_LEVEL_5 = "adventure-12"


class GandalfTarget(PromptTarget):
Expand Down Expand Up @@ -60,27 +65,6 @@ def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
if prompt_request.request_pieces[0].converted_value_data_type != "text":
raise ValueError("This target only supports text prompt input.")

async def check_password(self, password: str) -> bool:
"""
Checks if the password is correct

True means the password is correct, False means it is not
"""
payload: dict[str, object] = {
"defender": self._defender,
"password": password,
}

resp = await net_utility.make_request_and_raise_if_error_async(
endpoint_uri=self._endpoint, method="POST", request_body=payload, post_type="data"
)

if not resp.text:
raise ValueError("The chat returned an empty response.")

json_response = resp.json()
return json_response["success"]

async def _complete_text_async(self, text: str) -> str:
payload: dict[str, object] = {
"defender": self._defender,
Expand Down
6 changes: 3 additions & 3 deletions pyrit/prompt_target/prompt_chat_target/ollama_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def _construct_http_body(
self,
messages: list[ChatMessage],
) -> dict:
squased_messages = self.chat_message_normalizer.normalize(messages)
messages_dict = [message.model_dump(exclude_none=True) for message in squased_messages]
squashed_messages = self.chat_message_normalizer.normalize(messages)
messages_list = [message.model_dump(exclude_none=True) for message in squashed_messages]
data = {
"model": self.model_name,
"messages": messages_dict,
"messages": messages_list,
"stream": False,
}
return data
Expand Down
87 changes: 44 additions & 43 deletions pyrit/score/__init__.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.models import Score, ScoreType
from pyrit.score.scorer import Scorer

from pyrit.score.azure_content_filter_scorer import AzureContentFilterScorer
from pyrit.score.float_scale_threshold_scorer import FloatScaleThresholdScorer
from pyrit.score.gandalf_scorer import GandalfScorer
from pyrit.score.human_in_the_loop_scorer import HumanInTheLoopScorer
from pyrit.score.markdown_injection import MarkdownInjectionScorer
from pyrit.score.prompt_shield_scorer import PromptShieldScorer
from pyrit.score.self_ask_category_scorer import SelfAskCategoryScorer, ContentClassifierPaths
from pyrit.score.self_ask_likert_scorer import SelfAskLikertScorer, LikertScalePaths
from pyrit.score.self_ask_scale_scorer import SelfAskScaleScorer, ScalePaths
from pyrit.score.self_ask_true_false_scorer import SelfAskTrueFalseScorer, TrueFalseQuestionPaths
from pyrit.score.substring_scorer import SubStringScorer
from pyrit.score.true_false_inverter_scorer import TrueFalseInverterScorer


__all__ = [
"AzureContentFilterScorer",
"ContentClassifierPaths",
"FloatScaleThresholdScorer",
"GandalfScorer",
"HumanInTheLoopScorer",
"LikertScalePaths",
"MarkdownInjectionScorer",
"MetaScorerQuestionPaths",
"ObjectiveQuestionPaths",
"PromptShieldScorer",
"ScalePaths",
"Score",
"ScoreType",
"Scorer",
"SelfAskCategoryScorer",
"SelfAskLikertScorer",
"SelfAskScaleScorer",
"SelfAskTrueFalseScorer",
"SubStringScorer",
"TrueFalseInverterScorer",
"TrueFalseQuestionPaths",
]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.models import Score, ScoreType
from pyrit.score.scorer import Scorer

from pyrit.score.azure_content_filter_scorer import AzureContentFilterScorer
from pyrit.score.float_scale_threshold_scorer import FloatScaleThresholdScorer
from pyrit.score.gandalf_scorer import GandalfScorer, gandalf_tongue_tied_scorer
from pyrit.score.human_in_the_loop_scorer import HumanInTheLoopScorer
from pyrit.score.markdown_injection import MarkdownInjectionScorer
from pyrit.score.prompt_shield_scorer import PromptShieldScorer
from pyrit.score.self_ask_category_scorer import SelfAskCategoryScorer, ContentClassifierPaths
from pyrit.score.self_ask_likert_scorer import SelfAskLikertScorer, LikertScalePaths
from pyrit.score.self_ask_scale_scorer import SelfAskScaleScorer, ScalePaths
from pyrit.score.self_ask_true_false_scorer import SelfAskTrueFalseScorer, TrueFalseQuestionPaths
from pyrit.score.substring_scorer import SubStringScorer
from pyrit.score.true_false_inverter_scorer import TrueFalseInverterScorer


__all__ = [
"AzureContentFilterScorer",
"ContentClassifierPaths",
"FloatScaleThresholdScorer",
"GandalfScorer",
"gandalf_tongue_tied_scorer",
"HumanInTheLoopScorer",
"LikertScalePaths",
"MarkdownInjectionScorer",
"MetaScorerQuestionPaths",
"ObjectiveQuestionPaths",
"PromptShieldScorer",
"ScalePaths",
"Score",
"ScoreType",
"Scorer",
"SelfAskCategoryScorer",
"SelfAskLikertScorer",
"SelfAskScaleScorer",
"SelfAskTrueFalseScorer",
"SubStringScorer",
"TrueFalseInverterScorer",
"TrueFalseQuestionPaths",
]
Loading