Skip to content

Commit 2bc7cd6

Browse files
committed
Adding output recipe, cleaning up some imports
1 parent 7ee9f25 commit 2bc7cd6

File tree

3 files changed (+18, −18 lines changed)

litellm/proxy/guardrails/guardrail_hooks/pangea.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
# litellm/proxy/guardrails/guardrail_hooks/pangea.py
22
import os
3-
import sys
4-
5-
# Adds the parent directory to the system path to allow importing litellm modules
6-
sys.path.insert(
7-
0, os.path.abspath("../../..")
8-
)
9-
import json
103
from typing import Any, List, Literal, Optional, Union
114

125
from fastapi import HTTPException
136

14-
import litellm
157
from litellm._logging import verbose_proxy_logger
168
from litellm.integrations.custom_guardrail import (
179
CustomGuardrail,
@@ -50,7 +42,8 @@ class PangeaHandler(CustomGuardrail):
5042
def __init__(
5143
self,
5244
guardrail_name: str,
53-
pangea_recipe: str,
45+
pangea_input_recipe: Optional[str] = None,
46+
pangea_output_recipe: Optional[str] = None,
5447
api_key: Optional[str] = None,
5548
api_base: Optional[str] = None,
5649
**kwargs,
@@ -80,20 +73,22 @@ def __init__(
8073
or os.environ.get("PANGEA_API_BASE")
8174
or "https://ai-guard.aws.us.pangea.cloud"
8275
)
83-
self.pangea_recipe = pangea_recipe
76+
self.pangea_input_recipe = pangea_input_recipe
77+
self.pangea_output_recipe = pangea_output_recipe
8478
self.guardrail_endpoint = f"{self.api_base}/v1/text/guard"
8579

8680
# Pass relevant kwargs to the parent class
8781
super().__init__(guardrail_name=guardrail_name, **kwargs)
8882
verbose_proxy_logger.info(
89-
f"Initialized Pangea Guardrail: name={guardrail_name}, recipe={pangea_recipe}, api_base={self.api_base}"
83+
f"Initialized Pangea Guardrail: name={guardrail_name}, recipe={pangea_input_recipe}, api_base={self.api_base}"
9084
)
9185

9286
def _prepare_payload(
9387
self,
9488
messages: Optional[List[AllMessageValues]] = None,
9589
text_input: Optional[str] = None,
9690
request_data: Optional[dict] = None,
91+
recipe: Optional[str] = None,
9792
) -> dict:
9893
"""
9994
Prepares the payload for the Pangea AI Guard API request.
@@ -107,9 +102,12 @@ def _prepare_payload(
107102
dict: The payload dictionary for the API request.
108103
"""
109104
payload: dict[str, Any] = {
110-
"recipe": self.pangea_recipe,
111105
"debug": False, # Or make this configurable if needed
112106
}
107+
108+
if recipe:
109+
payload["recipe"] = recipe
110+
113111
if messages:
114112
# Ensure messages are in the format Pangea expects (list of dicts with 'role' and 'content')
115113
payload["messages"] = [
@@ -253,7 +251,7 @@ async def async_moderation_hook(
253251

254252
try:
255253
payload = self._prepare_payload(
256-
messages=messages, text_input=text_input, request_data=data
254+
messages=messages, text_input=text_input, request_data=data, recipe=self.pangea_input_recipe
257255
)
258256
await self._call_pangea_guard(
259257
payload=payload, request_data=data, hook_name="moderation_hook"
@@ -303,7 +301,7 @@ async def async_post_call_success_hook(
303301
try:
304302
# Scan only the output text in the post-call hook
305303
payload = self._prepare_payload(
306-
text_input=response_str, request_data=data
304+
text_input=response_str, request_data=data, recipe=self.pangea_output_recipe
307305
)
308306
await self._call_pangea_guard(
309307
payload=payload,
@@ -321,4 +319,4 @@ async def async_post_call_success_hook(
321319
"error": f"Error preparing Pangea payload for response: {ve}",
322320
"guardrail_name": self.guardrail_name,
323321
},
324-
)
322+
)

litellm/proxy/guardrails/guardrail_initializers.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,10 @@ def initialize_pangea(litellm_params, guardrail):
121121

122122
_pangea_callback = PangeaHandler(
123123
guardrail_name=guardrail["guardrail_name"],
124-
pangea_recipe=litellm_params["pangea_recipe"],
124+
pangea_input_recipe=litellm_params["pangea_input_recipe"],
125+
pangea_output_recipe=litellm_params["pangea_output_recipe"],
125126
api_base=litellm_params["api_base"],
126127
api_key=litellm_params["api_key"],
127128
default_on=litellm_params["default_on"],
128129
)
129-
litellm.logging_callback_manager.add_litellm_callback(_pangea_callback)
130+
litellm.logging_callback_manager.add_litellm_callback(_pangea_callback)

litellm/types/guardrails.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ class LitellmParams(TypedDict):
115115
] # will mask response content if guardrail makes any changes
116116

117117
# pangea params
118-
pangea_recipe: Optional[str]
118+
pangea_input_recipe: Optional[str]
119+
pangea_output_recipe: Optional[str]
119120

120121
class Guardrail(TypedDict, total=False):
121122
guardrail_name: str

0 commit comments

Comments (0)