diff --git a/pyrit/orchestrator/skeleton_key_orchestrator.py b/pyrit/orchestrator/skeleton_key_orchestrator.py index 34051d6f7..047afcca2 100644 --- a/pyrit/orchestrator/skeleton_key_orchestrator.py +++ b/pyrit/orchestrator/skeleton_key_orchestrator.py @@ -23,11 +23,12 @@ PromptConverterConfiguration, ) from pyrit.prompt_target import PromptTarget +from pyrit.orchestrator.single_turn.prompt_sending_orchestrator import PromptSendingOrchestrator logger = logging.getLogger(__name__) -class SkeletonKeyOrchestrator(Orchestrator): +class SkeletonKeyOrchestrator(PromptSendingOrchestrator): """ Creates an orchestrator that executes a skeleton key jailbreak. @@ -60,9 +61,12 @@ def __init__( ensure proper rate limit management. verbose (bool, Optional): If set to True, verbose output will be enabled. Defaults to False. """ - super().__init__(prompt_converters=prompt_converters, verbose=verbose) - - self._prompt_normalizer = PromptNormalizer() + super().__init__( + objective_target=prompt_target, + prompt_converters=prompt_converters, + batch_size=batch_size, + verbose=verbose + ) self._skeleton_key_prompt = ( skeleton_key_prompt @@ -74,80 +78,89 @@ def __init__( .value ) - self._prompt_target = prompt_target - - self._batch_size = batch_size - - async def send_skeleton_key_with_prompt_async( + async def send_skeleton_key_with_prompts_async( self, *, - prompt: str, - ) -> PromptRequestResponse: + prompt_list: list[str], + ) -> list[PromptRequestResponse]: """ - Sends a skeleton key, followed by the attack prompt to the target. - - Args + Sends a skeleton key and prompt to the target for each prompt in a list of prompts. - prompt (str): The prompt to be sent. - prompt_type (PromptDataType, Optional): The type of the prompt (e.g., "text"). Defaults to "text". + Args: + prompt_list (list[str]): The list of prompts to be sent. Returns: - PromptRequestResponse: The response from the prompt target. + list[PromptRequestResponse]: The responses from the prompt target. """ + if hasattr(self._prompt_target, 'rpm') and self._prompt_target.rpm and self._batch_size != 1: + raise ValueError( + "When using a prompt target with max_requests_per_minute, batch_size must be set to 1" + ) + # Create a single conversation ID for the entire sequence conversation_id = str(uuid4()) + metadata = {"conversation_id": conversation_id} - skeleton_key_prompt = SeedPromptGroup(prompts=[SeedPrompt(value=self._skeleton_key_prompt, data_type="text")]) - - converter_configuration = PromptConverterConfiguration(converters=self._prompt_converters) - - await self._prompt_normalizer.send_prompt_async( - seed_prompt_group=skeleton_key_prompt, - conversation_id=conversation_id, - request_converter_configurations=[converter_configuration], - target=self._prompt_target, - labels=self._global_memory_labels, - orchestrator_identifier=self.get_identifier(), + # First, send all skeleton keys + skeleton_keys = [self._skeleton_key_prompt] * len(prompt_list) + await self.send_prompts_async( + prompt_list=skeleton_keys, + metadata=metadata ) - objective_prompt = SeedPromptGroup(prompts=[SeedPrompt(value=prompt, data_type="text")]) - - return await self._prompt_normalizer.send_prompt_async( - seed_prompt_group=objective_prompt, - conversation_id=conversation_id, - request_converter_configurations=[converter_configuration], - target=self._prompt_target, - labels=self._global_memory_labels, - orchestrator_identifier=self.get_identifier(), + # Then send all attack prompts with the same conversation ID + attack_responses = await self.send_prompts_async( + prompt_list=prompt_list, + metadata=metadata ) + + return attack_responses - async def send_skeleton_key_with_prompts_async( + async def send_skeleton_key_with_prompt_async( self, *, - prompt_list: list[str], - ) -> list[PromptRequestResponse]: + prompt: str, + ) -> PromptRequestResponse: """ - Sends a skeleton key and prompt to the target for each prompt in a list of prompts. + Sends a skeleton key, followed by the attack prompt to the target. Args: - prompt_list (list[str]): The list of prompts to be sent. - prompt_type (PromptDataType, Optional): The type of the prompts (e.g., "text"). Defaults to "text". + prompt (str): The prompt to be sent. Returns: - list[PromptRequestResponse]: The responses from the prompt target. + PromptRequestResponse: The response from the prompt target. """ + # Create a single conversation ID + conversation_id = str(uuid4()) + metadata = {"conversation_id": conversation_id} + + # Create normalizer requests for both prompts + skeleton_request = self._create_normalizer_request( + prompt_text=self._skeleton_key_prompt, + prompt_type="text", + converters=self._prompt_converters, + metadata=metadata, + conversation_id=conversation_id, + ) - return await batch_task_async( - task_func=self.send_skeleton_key_with_prompt_async, - task_arguments=["prompt"], - prompt_target=self._prompt_target, - batch_size=self._batch_size, - items_to_batch=[prompt_list], + attack_request = self._create_normalizer_request( + prompt_text=prompt, + prompt_type="text", + converters=self._prompt_converters, + metadata=metadata, + conversation_id=conversation_id, ) + # Send both requests in a single batch + responses = await self.send_normalizer_requests_async( + prompt_request_list=[skeleton_request, attack_request] + ) + + # Return the attack prompt response (second response) + return responses[1] + def print_conversation(self) -> None: """Prints all the conversations that have occured with the prompt target.""" - target_messages = self.get_memory() if not target_messages or len(target_messages) == 0: @@ -158,4 +171,4 @@ def print_conversation(self) -> None: if message.role == "user": print(f"{Style.BRIGHT}{Fore.RED}{message.role}: {message.converted_value}\n") else: - print(f"{Style.BRIGHT}{Fore.GREEN}{message.role}: {message.converted_value}\n") + print(f"{Style.BRIGHT}{Fore.GREEN}{message.role}: {message.converted_value}\n") \ No newline at end of file