From 7c4bdb83f23ddc8bbbdffca124256b91d72b017b Mon Sep 17 00:00:00 2001
From: MoolmanM
Date: Fri, 21 Feb 2025 00:27:07 +0200
Subject: [PATCH] FIX: Refactor fetch_red_team_social_bias_prompts function

---
 pyrit/datasets/fetch_example_datasets.py | 79 ++++++++++++++++++------
 1 file changed, 59 insertions(+), 20 deletions(-)

diff --git a/pyrit/datasets/fetch_example_datasets.py b/pyrit/datasets/fetch_example_datasets.py
index d9a6fdab9..56c35eb84 100644
--- a/pyrit/datasets/fetch_example_datasets.py
+++ b/pyrit/datasets/fetch_example_datasets.py
@@ -512,32 +512,71 @@ def fetch_red_team_social_bias_prompts_dataset() -> SeedPromptDataset:
     """
     data = load_dataset("svannie678/red_team_repo_social_bias_prompts", "default")
 
-    # Filter items where prompt_type is "Single Prompt"
-    filtered_data = [item for item in data["train"] if item.get("prompt_type") == "Single Prompt"]
+    # Define common metadata that will be used across all seed prompts
+    common_metadata = {
+        "dataset_name": "svannie678/red_team_repo_social_bias_prompts",
+        "authors": ["Simone van Taylor"],
+        "description": """This dataset contains aggregated and unified existing
+        red-teaming prompts designed to identify stereotypes,
+        discrimination, hate speech, and other representation
+        harms in text-based Large Language Models (LLMs).""",
+        "source": "https://huggingface.co/datasets/svannie678/red_team_repo_social_bias_prompts"
+    }
 
-    seed_prompts = [
-        SeedPrompt(
-            value=item.get("prompt", ""),
-            data_type="text",
-            name="svannie678/red_team_repo_social_bias_prompts",
-            dataset_name="svannie678/red_team_repo_social_bias_prompts",
-            authors=["Simone van Taylor"],
-            description="""This dataset contains aggregated and unified existing
-            red-teaming prompts designed to identify stereotypes,
-            discrimination, hate speech, and other representation
-            harms in text-based Large Language Models (LLMs).""",
-            source="https://huggingface.co/datasets/svannie678/red_team_repo_social_bias_prompts",
-            harm_categories=item.get("categorization", []),
-            metadata={
+    # Initialize an empty list to store the seed prompts
+    seed_prompts = []
+
+    for group_id, item in enumerate(data["train"]):
+
+        # This dataset contains 3 prompt types: "Single Prompt", "Multi Turn" and "Multi Turn, Single Prompt"
+        # We're only checking for "Multi Turn" because the "Multi Turn, Single Prompt"
+        # prompt types inside this dataset will be difficult to parse
+        prompt_type = item.get("prompt_type")
+
+        if prompt_type is None:
+            continue
+
+        # Dictionary of metadata for the current prompt
+        prompt_metadata = {
+            **common_metadata,
+            "harm_categories": item.get("categorization", []),
+            "metadata": {
                 "organization": item.get("organization", ""),
-                "prompt_type": item.get("prompt_type", ""),
+                "prompt_type": prompt_type,
                 "prompt_instruction": item.get("prompt_instruction", ""),
                 "explanation": item.get("explanation", ""),
                 "ai_response": item.get("ai_response", ""),
             },
-        )
-        for item in filtered_data
-    ]
+        }
+
+        if prompt_type in ["Multi Turn"]:
+            try:
+                # Safely parse the user prompts, remove the unwanted ones such as "assistant" and "system"
+                user_prompts = [turn["body"] for turn in ast.literal_eval(item.get("prompt", "")) if turn["role"].startswith("user")]
+
+                for i, user_prompt in enumerate(user_prompts):
+                    seed_prompts.append(
+                        SeedPrompt(
+                            value=user_prompt,
+                            data_type="text",
+                            prompt_group_id=group_id,
+                            sequence=i,
+                            **prompt_metadata
+                        )
+                    )
+            except Exception as e:
+                print(f"Error processing Multi-Turn Prompt: {e}")
+        else:
+            # Clean up non-"Multi Turn" prompts that contain unwanted lines of text
+            cleaned_value = item.get("prompt", "").replace("### Response:", "").replace("### Instruction:", "").strip()
+
+            seed_prompts.append(
+                SeedPrompt(
+                    value=cleaned_value,
+                    data_type="text",
+                    **prompt_metadata
+                )
+            )
 
     seed_prompt_dataset = SeedPromptDataset(prompts=seed_prompts)
     return seed_prompt_dataset
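
Reviewer note on the "Multi Turn" branch above: the patch treats each
multi-turn "prompt" field as a stringified Python list of role/body
dicts, which is why it reaches for ast.literal_eval (json.loads would
reject Python-style single-quoted strings, and plain eval would execute
arbitrary code). A minimal standalone sketch of that parsing step,
using a made-up record shaped the way the patch expects (real dataset
rows may differ):

    import ast

    # Hypothetical "Multi Turn" record; only the shape matters here
    raw_prompt = (
        "[{'role': 'user', 'body': 'First user turn'},"
        " {'role': 'assistant', 'body': 'Model reply'},"
        " {'role': 'user', 'body': 'Follow-up user turn'}]"
    )

    # literal_eval evaluates the stringified list safely, without
    # running arbitrary code
    turns = ast.literal_eval(raw_prompt)
    user_prompts = [turn["body"] for turn in turns if turn["role"].startswith("user")]
    print(user_prompts)  # ['First user turn', 'Follow-up user turn']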
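
To sanity-check the new grouping behaviour end to end, something like
the sketch below could be run. It assumes the fetcher is re-exported
from pyrit.datasets like the other fetch_* helpers, and that prompts
created without an explicit prompt_group_id leave that field as None;
both are assumptions, not guarantees from this diff:

    from collections import defaultdict

    from pyrit.datasets import fetch_red_team_social_bias_prompts_dataset

    dataset = fetch_red_team_social_bias_prompts_dataset()

    # Prompts split out of the same multi-turn row share a
    # prompt_group_id; sequence preserves the turn order
    groups = defaultdict(list)
    for prompt in dataset.prompts:
        if prompt.prompt_group_id is not None:
            groups[prompt.prompt_group_id].append(prompt)

    # Print the first few reassembled conversations
    for group_id, turns in list(groups.items())[:3]:
        for turn in sorted(turns, key=lambda p: p.sequence):
            print(group_id, turn.sequence, turn.value[:60])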