Skip to content

Commit 22a63aa

Browse files
authored
Adding tool input and output guardrails (#1792)
- This PR was started from [PR 1606: Tool Guardrails](#1606) - It adds input and output guardrails at the tool level which can trigger `ToolInputGuardrailTripwireTriggered` and `ToolOutputGuardrailTripwireTriggered` exceptions - It includes updated documentation, a runnable example, and unit tests - `make check` and unit tests all pass ## Edits since last review: - Extracted nested tool running logic in `_run_impl.py` - Added rejecting tool call or tool call output and returning a message to the model (rather than only raising an exception) - Added the tool guardrail results to the `RunResult` - Removed docs
1 parent c02d863 commit 22a63aa

File tree

12 files changed

+1309
-21
lines changed

12 files changed

+1309
-21
lines changed

docs/guardrails.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,4 @@ async def main():
151151
1. This is the actual agent's output type.
152152
2. This is the guardrail's output type.
153153
3. This is the guardrail function that receives the agent's output, and returns the result.
154-
4. This is the actual agent that defines the workflow.
154+
4. This is the actual agent that defines the workflow.

examples/basic/tool_guardrails.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import asyncio
2+
import json
3+
4+
from agents import (
5+
Agent,
6+
Runner,
7+
ToolGuardrailFunctionOutput,
8+
ToolInputGuardrailData,
9+
ToolOutputGuardrailData,
10+
ToolOutputGuardrailTripwireTriggered,
11+
function_tool,
12+
tool_input_guardrail,
13+
tool_output_guardrail,
14+
)
15+
16+
17+
@function_tool
18+
def send_email(to: str, subject: str, body: str) -> str:
19+
"""Send an email to the specified recipient."""
20+
return f"Email sent to {to} with subject '{subject}'"
21+
22+
23+
@function_tool
24+
def get_user_data(user_id: str) -> dict[str, str]:
25+
"""Get user data by ID."""
26+
# Simulate returning sensitive data
27+
return {
28+
"user_id": user_id,
29+
"name": "John Doe",
30+
"email": "[email protected]",
31+
"ssn": "123-45-6789", # Sensitive data that should be blocked!
32+
"phone": "555-1234",
33+
}
34+
35+
36+
@function_tool
37+
def get_contact_info(user_id: str) -> dict[str, str]:
38+
"""Get contact info by ID."""
39+
return {
40+
"user_id": user_id,
41+
"name": "Jane Smith",
42+
"email": "[email protected]",
43+
"phone": "555-1234",
44+
}
45+
46+
47+
@tool_input_guardrail
48+
def reject_sensitive_words(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
49+
"""Reject tool calls that contain sensitive words in arguments."""
50+
try:
51+
args = json.loads(data.context.tool_arguments) if data.context.tool_arguments else {}
52+
except json.JSONDecodeError:
53+
return ToolGuardrailFunctionOutput(output_info="Invalid JSON arguments")
54+
55+
# Check for suspicious content
56+
sensitive_words = [
57+
"password",
58+
"hack",
59+
"exploit",
60+
"malware",
61+
"ACME",
62+
]
63+
for key, value in args.items():
64+
value_str = str(value).lower()
65+
for word in sensitive_words:
66+
if word.lower() in value_str:
67+
# Reject tool call and inform the model the function was not called
68+
return ToolGuardrailFunctionOutput.reject_content(
69+
message=f"🚨 Tool call blocked: contains '{word}'",
70+
output_info={"blocked_word": word, "argument": key},
71+
)
72+
73+
return ToolGuardrailFunctionOutput(output_info="Input validated")
74+
75+
76+
@tool_output_guardrail
77+
def block_sensitive_output(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
78+
"""Block tool outputs that contain sensitive data."""
79+
output_str = str(data.output).lower()
80+
81+
# Check for sensitive data patterns
82+
if "ssn" in output_str or "123-45-6789" in output_str:
83+
# Use raise_exception to halt execution completely for sensitive data
84+
return ToolGuardrailFunctionOutput.raise_exception(
85+
output_info={"blocked_pattern": "SSN", "tool": data.context.tool_name},
86+
)
87+
88+
return ToolGuardrailFunctionOutput(output_info="Output validated")
89+
90+
91+
@tool_output_guardrail
92+
def reject_phone_numbers(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
93+
"""Reject function output containing phone numbers."""
94+
output_str = str(data.output)
95+
if "555-1234" in output_str:
96+
return ToolGuardrailFunctionOutput.reject_content(
97+
message="User data not retrieved as it contains a phone number which is restricted.",
98+
output_info={"redacted": "phone_number"},
99+
)
100+
return ToolGuardrailFunctionOutput(output_info="Phone number check passed")
101+
102+
103+
# Apply guardrails to tools
104+
send_email.tool_input_guardrails = [reject_sensitive_words]
105+
get_user_data.tool_output_guardrails = [block_sensitive_output]
106+
get_contact_info.tool_output_guardrails = [reject_phone_numbers]
107+
108+
agent = Agent(
109+
name="Secure Assistant",
110+
instructions="You are a helpful assistant with access to email and user data tools.",
111+
tools=[send_email, get_user_data, get_contact_info],
112+
)
113+
114+
115+
async def main():
116+
print("=== Tool Guardrails Example ===\n")
117+
118+
try:
119+
# Example 1: Normal operation - should work fine
120+
print("1. Normal email sending:")
121+
result = await Runner.run(agent, "Send a welcome email to [email protected]")
122+
print(f"✅ Successful tool execution: {result.final_output}\n")
123+
124+
# Example 2: Input guardrail triggers - function tool call is rejected but execution continues
125+
print("2. Attempting to send email with suspicious content:")
126+
result = await Runner.run(
127+
agent, "Send an email to [email protected] introducing the company ACME corp."
128+
)
129+
print(f"❌ Guardrail rejected function tool call: {result.final_output}\n")
130+
except Exception as e:
131+
print(f"Error: {e}\n")
132+
133+
try:
134+
# Example 3: Output guardrail triggers - should raise exception for sensitive data
135+
print("3. Attempting to get user data (contains SSN). Execution blocked:")
136+
result = await Runner.run(agent, "Get the data for user ID user123")
137+
print(f"✅ Successful tool execution: {result.final_output}\n")
138+
except ToolOutputGuardrailTripwireTriggered as e:
139+
print("🚨 Output guardrail triggered: Execution halted for sensitive data")
140+
print(f"Details: {e.output.output_info}\n")
141+
142+
try:
143+
# Example 4: Output guardrail triggers - reject returning function tool output but continue execution
144+
print("4. Rejecting function tool output containing phone numbers:")
145+
result = await Runner.run(agent, "Get contact info for user456")
146+
print(f"❌ Guardrail rejected function tool output: {result.final_output}\n")
147+
except Exception as e:
148+
print(f"Error: {e}\n")
149+
150+
151+
if __name__ == "__main__":
152+
asyncio.run(main())
153+
154+
"""
155+
Example output:
156+
157+
=== Tool Guardrails Example ===
158+
159+
1. Normal email sending:
160+
✅ Successful tool execution: I've sent a welcome email to [email protected] with an appropriate subject and greeting message.
161+
162+
2. Attempting to send email with suspicious content:
163+
❌ Guardrail rejected function tool call: I'm unable to send the email as mentioning ACME Corp. is restricted.
164+
165+
3. Attempting to get user data (contains SSN). Execution blocked:
166+
🚨 Output guardrail triggered: Execution halted for sensitive data
167+
Details: {'blocked_pattern': 'SSN', 'tool': 'get_user_data'}
168+
169+
4. Rejecting function tool output containing sensitive data:
170+
❌ Guardrail rejected function tool output: I'm unable to retrieve the contact info for user456 because it contains restricted information.
171+
"""

src/agents/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
ModelBehaviorError,
2222
OutputGuardrailTripwireTriggered,
2323
RunErrorDetails,
24+
ToolInputGuardrailTripwireTriggered,
25+
ToolOutputGuardrailTripwireTriggered,
2426
UserError,
2527
)
2628
from .guardrail import (
@@ -83,6 +85,17 @@
8385
default_tool_error_function,
8486
function_tool,
8587
)
88+
from .tool_guardrails import (
89+
ToolGuardrailFunctionOutput,
90+
ToolInputGuardrail,
91+
ToolInputGuardrailData,
92+
ToolInputGuardrailResult,
93+
ToolOutputGuardrail,
94+
ToolOutputGuardrailData,
95+
ToolOutputGuardrailResult,
96+
tool_input_guardrail,
97+
tool_output_guardrail,
98+
)
8699
from .tracing import (
87100
AgentSpanData,
88101
CustomSpanData,
@@ -191,6 +204,8 @@ def enable_verbose_stdout_logging():
191204
"AgentsException",
192205
"InputGuardrailTripwireTriggered",
193206
"OutputGuardrailTripwireTriggered",
207+
"ToolInputGuardrailTripwireTriggered",
208+
"ToolOutputGuardrailTripwireTriggered",
194209
"DynamicPromptFunction",
195210
"GenerateDynamicPromptData",
196211
"Prompt",
@@ -204,6 +219,15 @@ def enable_verbose_stdout_logging():
204219
"GuardrailFunctionOutput",
205220
"input_guardrail",
206221
"output_guardrail",
222+
"ToolInputGuardrail",
223+
"ToolOutputGuardrail",
224+
"ToolGuardrailFunctionOutput",
225+
"ToolInputGuardrailData",
226+
"ToolInputGuardrailResult",
227+
"ToolOutputGuardrailData",
228+
"ToolOutputGuardrailResult",
229+
"tool_input_guardrail",
230+
"tool_output_guardrail",
207231
"handoff",
208232
"Handoff",
209233
"HandoffInputData",

0 commit comments

Comments
 (0)