From 061956906826e16bb9e9fef8bf5ebc96787a9276 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 11 Feb 2025 23:13:30 -0800 Subject: [PATCH 1/2] add initial version of multi-turn-orchestrators via scanner! --- pyrit/cli/__main__.py | 54 ++++++++++++++++--- .../basic_multi_turn_attack.yaml | 32 +++++++++++ scanner_configurations/prompt_send.yaml | 11 ++-- 3 files changed, 82 insertions(+), 15 deletions(-) create mode 100644 scanner_configurations/basic_multi_turn_attack.yaml diff --git a/pyrit/cli/__main__.py b/pyrit/cli/__main__.py index 73090ea88..f5a2d17b4 100644 --- a/pyrit/cli/__main__.py +++ b/pyrit/cli/__main__.py @@ -77,11 +77,11 @@ async def validate_config_and_run_async(config: Dict[str, Any], memory_labels: O objective_target = validate_target(config, target_key="objective_target") prompt_converters: list[PromptConverter] = [] # prompt_converters = validate_converters(config) - scorer = None - # TODO: need to find a solution for single/multiple scorers and scoring_targets - # scorers = validate_scorers(config) adversarial_chat = None - # adversarial_chat = validate_adversarial_chat(config) + if "adversarial_chat" in config: + adversarial_chat = validate_target(config, target_key="adversarial_chat") + scoring_target = validate_scoring_target(config, adversarial_chat=adversarial_chat) + objective_scorer = validate_objective_scorer(config, scoring_target=scoring_target) orchestrators = [] for scenario_config in scenarios: @@ -91,7 +91,8 @@ async def validate_config_and_run_async(config: Dict[str, Any], memory_labels: O objective_target=objective_target, adversarial_chat=adversarial_chat, prompt_converters=prompt_converters, - scorer=scorer, + scoring_target=scoring_target, + objective_scorer=objective_scorer, ) ) @@ -130,7 +131,8 @@ def validate_scenario( objective_target: PromptTarget, adversarial_chat: Optional[PromptChatTarget] = None, prompt_converters: Optional[List[PromptConverter]] = None, - scorer: Optional[Scorer] = None, + scoring_target: 
Optional[PromptChatTarget] = None, + objective_scorer: Optional[Scorer] = None, ) -> Orchestrator: if "type" not in scenario_config: raise KeyError("Scenario must contain a 'type' key.") @@ -150,7 +152,7 @@ def validate_scenario( # Some orchestrator arguments have their own configuration since they # are more complex. They are passed in as args to this function. - complex_arg_names = ["objective_target", "adversarial_chat", "prompt_converters", "scorer"] + complex_arg_names = ["objective_target", "adversarial_chat", "prompt_converters", "scoring_target", "objective_scorer"] for complex_arg_name in complex_arg_names: if complex_arg_name in scenario_args: raise ValueError( @@ -207,6 +209,44 @@ def validate_target(config: Dict[str, Any], target_key: str) -> PromptTarget: return target +def validate_scoring_target(config: Dict[str, Any], adversarial_chat: Optional[PromptChatTarget]) -> PromptChatTarget | None: + if "scoring" not in config: + return None + scoring_config = config["scoring"] + + # If a scoring_target has been configured use it. + # Otherwise, use the adversarial_chat target for scoring. 
+ if "scoring_target" in scoring_config: + return validate_target(scoring_config, target_key="scoring_target") + return adversarial_chat + + + def validate_objective_scorer(config: Dict[str, Any], scoring_target: Optional[PromptChatTarget]) -> Scorer | None: + if "scoring" not in config: + return None + scoring_config = config["scoring"] + if "objective_scorer" not in scoring_config: + return None + + scorer_args = deepcopy(scoring_config["objective_scorer"]) + + if "type" not in scorer_args: + raise KeyError("Scorer definition must contain a 'type' key.") + + scorer_type = scorer_args.pop("type") + + try: + scorer_module = import_module("pyrit.score") + scorer_class = getattr(scorer_module, scorer_type) + except Exception as ex: + raise RuntimeError(f"Failed to import scorer {scorer_type} from pyrit.score") from ex + + if scoring_target and "chat_target" in inspect.signature(scorer_class.__init__).parameters: + scorer_args["chat_target"] = scoring_target + + return scorer_class(**scorer_args) + + def main(args=None): parsed_args = parse_args(args) config_file = parsed_args.config_file diff --git a/scanner_configurations/basic_multi_turn_attack.yaml b/scanner_configurations/basic_multi_turn_attack.yaml new file mode 100644 index 000000000..a4446f166 --- /dev/null +++ b/scanner_configurations/basic_multi_turn_attack.yaml @@ -0,0 +1,32 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "RedTeamingOrchestrator" + # - type: "CrescendoOrchestrator" + # - type: "TreeOfAttacksWithPruningOrchestrator" + # depth: 2 +objective_target: + type: "OpenAIChatTarget" # "AzureMLChatTarget" | "HuggingFaceEndpointTarget" | ... 
+ # endpoint_env_variable: # in case one wants to use multiple you need to specify the env vars - not yet supported + # api_key_env_variable: + # any arg for targets can be listed here: + # deployment_name_env_variable: + # headers: +# converters: +# - type: "Base64Converter" +# - type: "LeetspeakConverter" +adversarial_chat: + type: "OpenAIChatTarget" + is_azure_target: true +scoring: + # scoring_target is optional. If a target is required but not provided, the adversarial_chat will be used for scoring + # scoring_target: + # type: "OpenAIChatTarget" + objective_scorer: + type: "SelfAskRefusalScorer" +memory_labels: + operator: roakey + operation: op_trash_panda +execution_settings: + type: local # or "azureml" + # parallel_nodes: 4 # how many scenarios to execute in parallel diff --git a/scanner_configurations/prompt_send.yaml b/scanner_configurations/prompt_send.yaml index 28e24f854..2e643f378 100644 --- a/scanner_configurations/prompt_send.yaml +++ b/scanner_configurations/prompt_send.yaml @@ -2,8 +2,6 @@ datasets: - ./pyrit/datasets/seed_prompts/illegal.prompt scenarios: - type: "PromptSendingOrchestrator" - # - type: "CrescendoOrchestrator" - # - type: "TreeOfAttackWithPruningOrchestrator" objective_target: type: "OpenAIChatTarget" # "AzureMLChatTarget" | "HuggingFaceEndpointTarget" | ... # endpoint_env_variable: # in case one wants to use multiple you need to specify the env vars - not yet supported @@ -15,14 +13,11 @@ objective_target: # converters: # - type: "Base64Converter" # - type: "LeetspeakConverter" -# adversarial_chat: -# type: "AzureMLChatTarget" -# ... # scoring: -# scorer: ... +# objective_scorer: ... 
memory_labels: - operator: romanlutz - operation: scanner_setup_jan25 + operator: roakey + operation: op_trash_panda execution_settings: type: local # or "azureml" # parallel_nodes: 4 # how many scenarios to execute in parallel From e4f4a32d7a631a4870f70ab655d98f8d34560b43 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Sun, 2 Mar 2025 23:23:54 -0800 Subject: [PATCH 2/2] first set of tests --- pyrit/cli/__main__.py | 2 +- .../basic_multi_turn_attack.yaml | 1 + ...d_multiple_orchestrators_args_success.yaml | 28 +++++++ .../multi_turn_crescendo_args_success.yaml | 13 ++++ .../cli/multi_turn_crescendo_success.yaml | 11 +++ .../cli/multi_turn_crescendo_wrong_arg.yaml | 13 ++++ ...n_multiple_orchestrators_args_success.yaml | 27 +++++++ .../unit/cli/multi_turn_rto_args_success.yaml | 11 +++ tests/unit/cli/multi_turn_rto_success.yaml | 11 +++ tests/unit/cli/multi_turn_rto_wrong_arg.yaml | 12 +++ .../unit/cli/multi_turn_tap_args_success.yaml | 14 ++++ tests/unit/cli/multi_turn_tap_success.yaml | 11 +++ tests/unit/cli/multi_turn_tap_wrong_arg.yaml | 14 ++++ tests/unit/cli/multi_turn_template.yaml | 32 ++++++++ .../cli/prompt_send_no_objective_target.yaml | 2 +- .../prompt_send_no_objective_target_type.yaml | 2 +- tests/unit/cli/test_cli.py | 77 +++++++++++++++++-- 17 files changed, 272 insertions(+), 9 deletions(-) create mode 100644 tests/unit/cli/mixed_multiple_orchestrators_args_success.yaml create mode 100644 tests/unit/cli/multi_turn_crescendo_args_success.yaml create mode 100644 tests/unit/cli/multi_turn_crescendo_success.yaml create mode 100644 tests/unit/cli/multi_turn_crescendo_wrong_arg.yaml create mode 100644 tests/unit/cli/multi_turn_multiple_orchestrators_args_success.yaml create mode 100644 tests/unit/cli/multi_turn_rto_args_success.yaml create mode 100644 tests/unit/cli/multi_turn_rto_success.yaml create mode 100644 tests/unit/cli/multi_turn_rto_wrong_arg.yaml create mode 100644 tests/unit/cli/multi_turn_tap_args_success.yaml create mode 100644 
tests/unit/cli/multi_turn_tap_success.yaml create mode 100644 tests/unit/cli/multi_turn_tap_wrong_arg.yaml create mode 100644 tests/unit/cli/multi_turn_template.yaml diff --git a/pyrit/cli/__main__.py b/pyrit/cli/__main__.py index f5a2d17b4..87a5b6b34 100644 --- a/pyrit/cli/__main__.py +++ b/pyrit/cli/__main__.py @@ -169,7 +169,7 @@ def validate_scenario( orchestrator = orchestrator_class(**scenario_args) except Exception as ex: - raise ValueError(f"Failed to validate scenario {scenario_type}") from ex + raise ValueError(f"Failed to validate scenario {scenario_type}: {ex}") from ex return orchestrator diff --git a/scanner_configurations/basic_multi_turn_attack.yaml b/scanner_configurations/basic_multi_turn_attack.yaml index a4446f166..8ccade2ba 100644 --- a/scanner_configurations/basic_multi_turn_attack.yaml +++ b/scanner_configurations/basic_multi_turn_attack.yaml @@ -12,6 +12,7 @@ objective_target: # any arg for targets can be listed here: # deployment_name_env_variable: # headers: + is_azure_target: true # converters: # - type: "Base64Converter" # - type: "LeetspeakConverter" diff --git a/tests/unit/cli/mixed_multiple_orchestrators_args_success.yaml b/tests/unit/cli/mixed_multiple_orchestrators_args_success.yaml new file mode 100644 index 000000000..b51fd0c8b --- /dev/null +++ b/tests/unit/cli/mixed_multiple_orchestrators_args_success.yaml @@ -0,0 +1,28 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "PromptSendingOrchestrator" + - type: "RedTeamingOrchestrator" + - type: "CrescendoOrchestrator" + max_turns: 5 + max_backtracks: 3 + - type: "CrescendoOrchestrator" + max_turns: 10 + max_backtracks: 10 + - type: "TreeOfAttacksWithPruningOrchestrator" + depth: 10 + width: 7 + branching_factor: 8 + - type: "TreeOfAttacksWithPruningOrchestrator" + depth: 5 + width: 3 + branching_factor: 6 +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + scoring_target: + type: 
"OpenAIChatTarget" + objective_scorer: + type: "SelfAskRefusalScorer" \ No newline at end of file diff --git a/tests/unit/cli/multi_turn_crescendo_args_success.yaml b/tests/unit/cli/multi_turn_crescendo_args_success.yaml new file mode 100644 index 000000000..c7ae3320f --- /dev/null +++ b/tests/unit/cli/multi_turn_crescendo_args_success.yaml @@ -0,0 +1,13 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "CrescendoOrchestrator" + max_turns: 8 + max_backtracks: 4 +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + scoring_target: + type: "OpenAIChatTarget" \ No newline at end of file diff --git a/tests/unit/cli/multi_turn_crescendo_success.yaml b/tests/unit/cli/multi_turn_crescendo_success.yaml new file mode 100644 index 000000000..70154aae6 --- /dev/null +++ b/tests/unit/cli/multi_turn_crescendo_success.yaml @@ -0,0 +1,11 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "CrescendoOrchestrator" +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + scoring_target: + type: "OpenAIChatTarget" \ No newline at end of file diff --git a/tests/unit/cli/multi_turn_crescendo_wrong_arg.yaml b/tests/unit/cli/multi_turn_crescendo_wrong_arg.yaml new file mode 100644 index 000000000..99144a4fe --- /dev/null +++ b/tests/unit/cli/multi_turn_crescendo_wrong_arg.yaml @@ -0,0 +1,13 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "CrescendoOrchestrator" + max_turns: 8 + wrong_arg: "wrong" +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + scoring_target: + type: "OpenAIChatTarget" \ No newline at end of file diff --git a/tests/unit/cli/multi_turn_multiple_orchestrators_args_success.yaml b/tests/unit/cli/multi_turn_multiple_orchestrators_args_success.yaml new file mode 100644 index 000000000..551d8edde --- /dev/null +++ 
b/tests/unit/cli/multi_turn_multiple_orchestrators_args_success.yaml @@ -0,0 +1,27 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "RedTeamingOrchestrator" + - type: "CrescendoOrchestrator" + max_turns: 5 + max_backtracks: 3 + - type: "CrescendoOrchestrator" + max_turns: 10 + max_backtracks: 10 + - type: "TreeOfAttacksWithPruningOrchestrator" + depth: 10 + width: 7 + branching_factor: 8 + - type: "TreeOfAttacksWithPruningOrchestrator" + depth: 5 + width: 3 + branching_factor: 6 +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + scoring_target: + type: "OpenAIChatTarget" + objective_scorer: + type: "SelfAskRefusalScorer" \ No newline at end of file diff --git a/tests/unit/cli/multi_turn_rto_args_success.yaml b/tests/unit/cli/multi_turn_rto_args_success.yaml new file mode 100644 index 000000000..7fd1473d2 --- /dev/null +++ b/tests/unit/cli/multi_turn_rto_args_success.yaml @@ -0,0 +1,11 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "RedTeamingOrchestrator" +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + objective_scorer: + type: "SelfAskRefusalScorer" diff --git a/tests/unit/cli/multi_turn_rto_success.yaml b/tests/unit/cli/multi_turn_rto_success.yaml new file mode 100644 index 000000000..7fd1473d2 --- /dev/null +++ b/tests/unit/cli/multi_turn_rto_success.yaml @@ -0,0 +1,11 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "RedTeamingOrchestrator" +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + objective_scorer: + type: "SelfAskRefusalScorer" diff --git a/tests/unit/cli/multi_turn_rto_wrong_arg.yaml b/tests/unit/cli/multi_turn_rto_wrong_arg.yaml new file mode 100644 index 000000000..de7823beb --- /dev/null +++ b/tests/unit/cli/multi_turn_rto_wrong_arg.yaml @@ -0,0 +1,12 @@ +datasets: + - 
./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "RedTeamingOrchestrator" + wrong_arg: "wrong" +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + objective_scorer: + type: "SelfAskRefusalScorer" diff --git a/tests/unit/cli/multi_turn_tap_args_success.yaml b/tests/unit/cli/multi_turn_tap_args_success.yaml new file mode 100644 index 000000000..cd94f68e2 --- /dev/null +++ b/tests/unit/cli/multi_turn_tap_args_success.yaml @@ -0,0 +1,14 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "TreeOfAttacksWithPruningOrchestrator" + depth: 10 + width: 7 + branching_factor: 8 +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + scoring_target: + type: "OpenAIChatTarget" \ No newline at end of file diff --git a/tests/unit/cli/multi_turn_tap_success.yaml b/tests/unit/cli/multi_turn_tap_success.yaml new file mode 100644 index 000000000..9be5a6477 --- /dev/null +++ b/tests/unit/cli/multi_turn_tap_success.yaml @@ -0,0 +1,11 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "TreeOfAttacksWithPruningOrchestrator" +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + scoring_target: + type: "OpenAIChatTarget" \ No newline at end of file diff --git a/tests/unit/cli/multi_turn_tap_wrong_arg.yaml b/tests/unit/cli/multi_turn_tap_wrong_arg.yaml new file mode 100644 index 000000000..6dfc0c15d --- /dev/null +++ b/tests/unit/cli/multi_turn_tap_wrong_arg.yaml @@ -0,0 +1,14 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "TreeOfAttacksWithPruningOrchestrator" + depth: 10 + width: 7 + wrong_arg: "wrong" +objective_target: + type: "OpenAIChatTarget" +adversarial_chat: + type: "OpenAIChatTarget" +scoring: + scoring_target: + type: "OpenAIChatTarget" \ No newline at end of file diff --git 
a/tests/unit/cli/multi_turn_template.yaml b/tests/unit/cli/multi_turn_template.yaml new file mode 100644 index 000000000..a4446f166 --- /dev/null +++ b/tests/unit/cli/multi_turn_template.yaml @@ -0,0 +1,32 @@ +datasets: + - ./pyrit/datasets/seed_prompts/illegal.prompt +scenarios: + - type: "RedTeamingOrchestrator" + # - type: "CrescendoOrchestrator" + # - type: "TreeOfAttacksWithPruningOrchestrator" + # depth: 2 +objective_target: + type: "OpenAIChatTarget" # "AzureMLChatTarget" | "HuggingFaceEndpointTarget" | ... + # endpoint_env_variable: # in case one wants to use multiple you need to specify the env vars - not yet supported + # api_key_env_variable: + # any arg for targets can be listed here: + # deployment_name_env_variable: + # headers: +# converters: +# - type: "Base64Converter" +# - type: "LeetspeakConverter" +adversarial_chat: + type: "OpenAIChatTarget" + is_azure_target: true +scoring: + # scoring_target is optional. If a target is required but not provided, the adversarial_chat will be used for scoring + # scoring_target: + # type: "OpenAIChatTarget" + objective_scorer: + type: "SelfAskRefusalScorer" +memory_labels: + operator: roakey + operation: op_trash_panda +execution_settings: + type: local # or "azureml" + # parallel_nodes: 4 # how many scenarios to execute in parallel diff --git a/tests/unit/cli/prompt_send_no_objective_target.yaml b/tests/unit/cli/prompt_send_no_objective_target.yaml index b50d9a519..c9da1e812 100644 --- a/tests/unit/cli/prompt_send_no_objective_target.yaml +++ b/tests/unit/cli/prompt_send_no_objective_target.yaml @@ -1,4 +1,4 @@ datasets: - ./pyrit/datasets/seed_prompts/illegal.prompt scenarios: - - type: send_prompts + - type: "PromptSendingOrchestrator" diff --git a/tests/unit/cli/prompt_send_no_objective_target_type.yaml b/tests/unit/cli/prompt_send_no_objective_target_type.yaml index b2179b340..c9404637b 100644 --- a/tests/unit/cli/prompt_send_no_objective_target_type.yaml +++ 
b/tests/unit/cli/prompt_send_no_objective_target_type.yaml @@ -1,5 +1,5 @@ datasets: - ./pyrit/datasets/seed_prompts/illegal.prompt scenarios: - - type: send_prompts + - type: "PromptSendingOrchestrator" objective_target: diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py index ab398c8a1..440a6ebb3 100644 --- a/tests/unit/cli/test_cli.py +++ b/tests/unit/cli/test_cli.py @@ -1,15 +1,63 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import contextlib +import re import shlex from unittest.mock import patch import pytest from pyrit.cli.__main__ import main -from pyrit.orchestrator import PromptSendingOrchestrator +from pyrit.orchestrator import PromptSendingOrchestrator, CrescendoOrchestrator, RedTeamingOrchestrator, TreeOfAttacksWithPruningOrchestrator -test_cases_success = ["--config-file 'tests/unit/cli/prompt_send_success.yaml'"] +test_cases_success = [ + ( + "--config-file 'tests/unit/cli/prompt_send_success.yaml'", + [PromptSendingOrchestrator], + ["send_normalizer_requests_async"] + ), + ( + "--config-file 'tests/unit/cli/multi_turn_rto_success.yaml'", + [RedTeamingOrchestrator], + ["run_attack_async"] + ), + ( + "--config-file 'tests/unit/cli/multi_turn_rto_args_success.yaml'", + [RedTeamingOrchestrator], + ["run_attack_async"] + ), + ( + "--config-file 'tests/unit/cli/multi_turn_crescendo_success.yaml'", + [CrescendoOrchestrator], + ["run_attack_async"] + ), + ( + "--config-file 'tests/unit/cli/multi_turn_crescendo_args_success.yaml'", + [CrescendoOrchestrator], + ["run_attack_async"] + ), + ( + "--config-file 'tests/unit/cli/multi_turn_tap_success.yaml'", + [TreeOfAttacksWithPruningOrchestrator], + ["run_attack_async"] + ), + ( + "--config-file 'tests/unit/cli/multi_turn_tap_args_success.yaml'", + [TreeOfAttacksWithPruningOrchestrator], + ["run_attack_async"] + ), + ( + "--config-file 'tests/unit/cli/multi_turn_multiple_orchestrators_args_success.yaml'", + [TreeOfAttacksWithPruningOrchestrator, 
CrescendoOrchestrator, RedTeamingOrchestrator], + ["run_attack_async", "run_attack_async", "run_attack_async"] + ), + ( + "--config-file 'tests/unit/cli/mixed_multiple_orchestrators_args_success.yaml'", + [PromptSendingOrchestrator, TreeOfAttacksWithPruningOrchestrator, CrescendoOrchestrator, RedTeamingOrchestrator], + ["send_normalizer_requests_async", "run_attack_async", "run_attack_async", "run_attack_async"] + ), +] test_cases_sys_exit = [ @@ -49,17 +97,34 @@ "Scenario must contain a 'type' key.", KeyError, ), + ( + "--config-file 'tests/unit/cli/multi_turn_rto_wrong_arg.yaml'", + "Failed to validate scenario RedTeamingOrchestrator: RedTeamingOrchestrator.__init__() got an unexpected keyword argument 'wrong_arg'", + ValueError, + ), + ( + "--config-file 'tests/unit/cli/multi_turn_crescendo_wrong_arg.yaml'", + "Failed to validate scenario CrescendoOrchestrator: CrescendoOrchestrator.__init__() got an unexpected keyword argument 'wrong_arg'", + ValueError, + ), + ( + "--config-file 'tests/unit/cli/multi_turn_tap_wrong_arg.yaml'", + "Failed to validate scenario TreeOfAttacksWithPruningOrchestrator: TreeOfAttacksWithPruningOrchestrator.__init__() got an unexpected keyword argument 'wrong_arg'", + ValueError, + ), ] -@pytest.mark.parametrize("command", test_cases_success) +@pytest.mark.parametrize("command, orchestrator_classes, methods", test_cases_success) # Patching OpenAI target initialization which depends on environment variables # which we are not providing here. @patch("pyrit.prompt_target.OpenAIChatTarget._initialize_azure_vars") -def test_cli_success(init_method, command): +def test_cli_pso_success(init_method, command, orchestrator_classes, methods): # Patching the request sending functionality since we don't want to test the orchestrator, # but just the CLI part. 
- with patch.object(PromptSendingOrchestrator, "send_normalizer_requests_async"): + with contextlib.ExitStack() as stack: + for orchestrator_class, method in zip(orchestrator_classes, methods): + stack.enter_context(patch.object(orchestrator_class, method)) main(shlex.split(command)) @@ -77,5 +142,5 @@ def test_cli_sys_exit(capsys, command, expected_output): # which we are not providing here. @patch("pyrit.prompt_target.OpenAIChatTarget._initialize_azure_vars") def test_cli_error(init_method, command, expected_output, error_type): - with pytest.raises(error_type, match=expected_output): + with pytest.raises(error_type, match=re.escape(expected_output)): main(shlex.split(command))