|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +import random |
| 5 | +import re |
| 6 | +import string |
| 7 | +from typing import List, Optional |
| 8 | + |
| 9 | +from pyrit.models import PromptDataType |
| 10 | +from pyrit.prompt_converter import ConverterResult, PromptConverter |
| 11 | + |
| 12 | + |
class InsertPunctuationConverter(PromptConverter):
    """
    Inserts punctuation into a prompt to test robustness.

    Punctuation insertion: inserting single punctuation characters from string.punctuation.
    Words in a prompt: a word is a maximal run of regex ``\\w`` characters, so it contains
    no whitespace and no punctuation other than the underscore.
    "a1b2c3" is a word; "a1 2" are 2 words; "a1,b,3" are 3 words.
    """

    # Punctuation used by convert_async when the caller does not supply a list.
    default_punctuation_list = [",", ".", "!", "?", ":", ";", "-"]

    def __init__(self, word_swap_ratio: float = 0.2, between_words: bool = True) -> None:
        """
        Initialize the converter with the word swap ratio and the insertion mode.

        Args:
            word_swap_ratio (float): Fraction of words to perturb, in (0, 1]. Defaults to 0.2.
            between_words (bool): If True, insert punctuation only before or after words.
                If False, insert punctuation at character positions, possibly inside words.
                Defaults to True.

        Raises:
            ValueError: If word_swap_ratio is not in the interval (0, 1].
        """
        # Swap ratio cannot be 0 or larger than 1
        if not 0 < word_swap_ratio <= 1:
            raise ValueError("word_swap_ratio must be between 0 to 1, as (0, 1].")

        self._word_swap_ratio = word_swap_ratio
        self._between_words = between_words

    def _is_valid_punctuation(self, punctuation_list: List[str]) -> bool:
        """
        Check that the list is non-empty and that every item is exactly one character
        taken from string.punctuation.

        Spaces, letters, digits, empty strings and multi-character strings are all invalid.

        Args:
            punctuation_list (List[str]): List of punctuations to validate.

        Returns:
            bool: True if the list is non-empty and every item is a single punctuation character.
        """
        # An empty list would make random.choice fail later, so reject it up front.
        if not punctuation_list:
            return False
        # The explicit length check matters: a bare `item in string.punctuation` is a
        # substring test, which accepts "" and adjacent runs such as ",-" or "./".
        return all(
            isinstance(item, str) and len(item) == 1 and item in string.punctuation for item in punctuation_list
        )

    async def convert_async(
        self, *, prompt: str, input_type: PromptDataType = "text", punctuation_list: Optional[List[str]] = None
    ) -> ConverterResult:
        """
        Convert the given prompt by inserting punctuation.

        Args:
            prompt (str): The text to convert.
            input_type (PromptDataType): The type of input data. Only "text" is supported.
            punctuation_list (Optional[List[str]]): List of punctuations to use for insertion.
                If None, default_punctuation_list is used.

        Returns:
            ConverterResult: A ConverterResult whose output_text is the modified prompt.

        Raises:
            ValueError: If the input type is not supported, or if punctuation_list is empty
                or contains anything other than single characters from string.punctuation.
        """
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        # If not specified, fall back to default_punctuation_list; otherwise validate the caller's list
        if punctuation_list is None:
            punctuation_list = self.default_punctuation_list
        elif not self._is_valid_punctuation(punctuation_list):
            raise ValueError(
                f"Invalid punctuations: {punctuation_list}."
                f" Only single characters from {string.punctuation} are allowed."
            )

        modified_prompt = self._insert_punctuation(prompt, punctuation_list)
        return ConverterResult(output_text=modified_prompt, output_type="text")

    def _insert_punctuation(self, prompt: str, punctuation_list: List[str]) -> str:
        """
        Insert punctuation into the prompt, dispatching on the configured insertion mode.

        Args:
            prompt (str): The text to modify.
            punctuation_list (List[str]): List of punctuations for insertion.

        Returns:
            str: The modified prompt with inserted punctuation from the helper method.
        """
        # Tokens are: runs of word characters, single non-word non-space characters, single whitespace characters
        words = re.findall(r"\w+|[^\w\s]|\s", prompt)
        # Indices of the actual "words", i.e. the tokens that do not start with a non-word character
        word_indices = [i for i, token in enumerate(words) if not re.match(r"\W", token)]

        # If there is no actual word in the prompt, insert one random punctuation at position 0
        if not word_indices:
            return random.choice(punctuation_list) + prompt

        # Ensure at least one punctuation is inserted; the ratio is <= 1 so this never exceeds the word count
        num_insertions = max(1, round(len(word_indices) * self._word_swap_ratio))

        if self._between_words:
            return self._insert_between_words(words, word_indices, num_insertions, punctuation_list)
        return self._insert_within_words(prompt, num_insertions, punctuation_list)

    def _insert_between_words(
        self, words: List[str], word_indices: List[int], num_insertions: int, punctuation_list: List[str]
    ) -> str:
        """
        Insert punctuation directly before or after randomly chosen words in the prompt.

        Args:
            words (List[str]): List of words, punctuations and whitespace tokens.
            word_indices (List[int]): Indices of the actual words in the words list.
            num_insertions (int): Number of punctuations to insert; must not exceed len(word_indices).
            punctuation_list (List[str]): Punctuations for insertion.

        Returns:
            str: The modified prompt with inserted punctuation, with leading and trailing
                whitespace stripped.
        """
        # Work on a copy so the caller's token list is left untouched.
        tokens = list(words)
        # Randomly choose num_insertions distinct indices from the actual word indices.
        insert_indices = random.sample(word_indices, num_insertions)
        INSERT_BEFORE = 0
        INSERT_AFTER = 1
        for index in insert_indices:
            if random.randint(INSERT_BEFORE, INSERT_AFTER) == INSERT_AFTER:
                tokens[index] += random.choice(punctuation_list)
            else:
                tokens[index] = random.choice(punctuation_list) + tokens[index]
        # Join the tokens back into the modified prompt
        return "".join(tokens).strip()

    def _insert_within_words(self, prompt: str, num_insertions: int, punctuation_list: List[str]) -> str:
        """
        Insert punctuation at random character positions in the prompt, possibly inside a word.

        Args:
            prompt (str): The prompt string.
            num_insertions (int): Number of punctuations to insert.
            punctuation_list (List[str]): Punctuations for insertion.

        Returns:
            str: The modified prompt with inserted punctuation, with leading and trailing
                whitespace stripped.
        """
        # List of chars in the prompt string
        prompt_list = list(prompt)

        if len(prompt_list) <= num_insertions:
            # Too few characters to sample distinct positions from: do a single insertion
            # at index 1. list.insert clamps the index, so for a prompt of 0 or 1 chars
            # the punctuation ends up at the end of the prompt.
            insert_indices = [1]
        else:
            # Candidate positions are 0 .. len-2: punctuation is placed before any
            # character except the last one, and is never appended at the end.
            insert_indices = random.sample(range(0, len(prompt_list) - 1), num_insertions)

        # Insert from the highest index down so that earlier insertions do not shift
        # the positions that were sampled against the original prompt.
        for index in sorted(insert_indices, reverse=True):
            prompt_list.insert(index, random.choice(punctuation_list))

        return "".join(prompt_list).strip()

    def input_supported(self, input_type: PromptDataType) -> bool:
        """
        Report whether the converter can handle the given input type.

        Args:
            input_type (PromptDataType): The type of input data.

        Returns:
            bool: True only for "text".
        """
        return input_type == "text"
0 commit comments