From 3d0543cca8dc252662a86ad3ae9f7722b82150dd Mon Sep 17 00:00:00 2001 From: Nicole Pellicena Date: Tue, 25 Feb 2025 14:46:30 +0000 Subject: [PATCH] FEAT: include scored_prompt_id in orchestrator_identifier of the system prompt (#725) --- pyrit/score/insecure_code_scorer.py | 1 + pyrit/score/scorer.py | 6 +- pyrit/score/self_ask_category_scorer.py | 1 + pyrit/score/self_ask_refusal_scorer.py | 1 + pyrit/score/self_ask_true_false_scorer.py | 1 + tests/unit/score/test_scorer.py | 68 +++++++++++++++++++++++ 6 files changed, 77 insertions(+), 1 deletion(-) diff --git a/pyrit/score/insecure_code_scorer.py b/pyrit/score/insecure_code_scorer.py index 0bc7f0b4a..20d1bdfad 100644 --- a/pyrit/score/insecure_code_scorer.py +++ b/pyrit/score/insecure_code_scorer.py @@ -64,6 +64,7 @@ async def score_async(self, request_response: PromptRequestPiece, *, task: Optio scored_prompt_id=request_response.id, category=self._harm_category, task=task, + orchestrator_identifier=request_response.orchestrator_identifier, ) # Modify the UnvalidatedScore parsing to check for 'score_value' diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index cdb515cc3..46779a234 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -218,6 +218,7 @@ async def _score_value_with_llm( scored_prompt_id: str, category: str = None, task: str = None, + orchestrator_identifier: dict[str, str] = None, ) -> UnvalidatedScore: """ Sends a request to a target, and takes care of retries. @@ -242,10 +243,13 @@ async def _score_value_with_llm( conversation_id = str(uuid.uuid4()) + if orchestrator_identifier: + orchestrator_identifier["scored_prompt_id"] = str(scored_prompt_id) + prompt_target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, - orchestrator_identifier=None, + orchestrator_identifier=orchestrator_identifier, ) prompt_metadata = {"response_format": "json"} scorer_llm_request = PromptRequestResponse( diff --git a/pyrit/score/self_ask_category_scorer.py b/pyrit/score/self_ask_category_scorer.py index b2a7c91c2..7aedac3c6 100644 --- a/pyrit/score/self_ask_category_scorer.py +++ b/pyrit/score/self_ask_category_scorer.py @@ -106,6 +106,7 @@ async def score_async(self, request_response: PromptRequestPiece, *, task: Optio prompt_request_data_type=request_response.converted_value_data_type, scored_prompt_id=request_response.id, task=task, + orchestrator_identifier=request_response.orchestrator_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value) diff --git a/pyrit/score/self_ask_refusal_scorer.py b/pyrit/score/self_ask_refusal_scorer.py index 4d129b250..33442ecfc 100644 --- a/pyrit/score/self_ask_refusal_scorer.py +++ b/pyrit/score/self_ask_refusal_scorer.py @@ -98,6 +98,7 @@ async def score_async(self, request_response: PromptRequestPiece, *, task: Optio scored_prompt_id=request_response.id, category=self._score_category, task=task, + orchestrator_identifier=request_response.orchestrator_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value) diff --git a/pyrit/score/self_ask_true_false_scorer.py b/pyrit/score/self_ask_true_false_scorer.py index 08098498a..19090d969 100644 --- a/pyrit/score/self_ask_true_false_scorer.py +++ b/pyrit/score/self_ask_true_false_scorer.py @@ -125,6 +125,7 @@ async def score_async(self, request_response: PromptRequestPiece, *, task: Optio scored_prompt_id=request_response.id, category=self._score_category, task=task, + orchestrator_identifier=request_response.orchestrator_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value) diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 4213163cf..00c755cb7 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -110,6 +110,74 @@ async def test_scorer_score_value_with_llm_exception_display_prompt_id(): ) +@pytest.mark.asyncio +async def test_scorer_score_value_with_llm_use_provided_orchestrator_identifier(good_json): + scorer = MockScorer() + scorer.scorer_type = "true_false" + + prompt_response = PromptRequestResponse( + request_pieces=[PromptRequestPiece(role="assistant", original_value=good_json)] + ) + chat_target = MagicMock(PromptChatTarget) + chat_target.send_prompt_async = AsyncMock(return_value=prompt_response) + chat_target.set_system_prompt = AsyncMock() + + expected_system_prompt = "system_prompt" + expected_orchestrator_id = "orchestrator_id" + expected_scored_prompt_id = "123" + + await scorer._score_value_with_llm( + prompt_target=chat_target, + system_prompt=expected_system_prompt, + prompt_request_value="prompt_request_value", + prompt_request_data_type="text", + scored_prompt_id=expected_scored_prompt_id, + category="category", + task="task", + orchestrator_identifier={"id": expected_orchestrator_id}, + ) + + chat_target.set_system_prompt.assert_called_once() + + _, set_sys_prompt_args = chat_target.set_system_prompt.call_args + assert set_sys_prompt_args["system_prompt"] == expected_system_prompt + assert isinstance(set_sys_prompt_args["conversation_id"], str) + assert set_sys_prompt_args["orchestrator_identifier"]["id"] == expected_orchestrator_id + assert set_sys_prompt_args["orchestrator_identifier"]["scored_prompt_id"] == expected_scored_prompt_id + + +@pytest.mark.asyncio +async def test_scorer_score_value_with_llm_does_not_add_score_prompt_id_for_empty_orchestrator_identifier(good_json): + scorer = MockScorer() + scorer.scorer_type = "true_false" + + prompt_response = PromptRequestResponse( + request_pieces=[PromptRequestPiece(role="assistant", original_value=good_json)] + ) + chat_target = MagicMock(PromptChatTarget) + chat_target.send_prompt_async = AsyncMock(return_value=prompt_response) + chat_target.set_system_prompt = AsyncMock() + + expected_system_prompt = "system_prompt" + + await scorer._score_value_with_llm( + prompt_target=chat_target, + system_prompt=expected_system_prompt, + prompt_request_value="prompt_request_value", + prompt_request_data_type="text", + scored_prompt_id="123", + category="category", + task="task", + ) + + chat_target.set_system_prompt.assert_called_once() + + _, set_sys_prompt_args = chat_target.set_system_prompt.call_args + assert set_sys_prompt_args["system_prompt"] == expected_system_prompt + assert isinstance(set_sys_prompt_args["conversation_id"], str) + assert not set_sys_prompt_args["orchestrator_identifier"] + + @pytest.mark.asyncio async def test_scorer_send_chat_target_async_good_response(good_json):