Skip to content

Commit 8e6f3d6

Browse files
committed
feat: add regex-based tool_name/tool_type matching for tool-permission
1 parent be71290 commit 8e6f3d6

File tree

5 files changed

+257
-110
lines changed

5 files changed

+257
-110
lines changed

docs/my-website/docs/proxy/guardrails/tool_permission.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,17 @@ guardrails:
2121
tool_name: "Bash"
2222
decision: "allow"
2323
- id: "allow_github_mcp"
24-
tool_name: "mcp__github_*"
24+
tool_name: "^mcp__github_.*$"
2525
decision: "allow"
2626
- id: "allow_aws_documentation"
27-
tool_name: "mcp__aws-documentation_*_documentation"
27+
tool_name: "^mcp__aws-documentation_.*_documentation$"
2828
decision: "allow"
2929
- id: "deny_read_commands"
3030
tool_name: "Read"
31-
decision: "Deny"
31+
decision: "deny"
3232
- id: "mail-domain"
33-
tool_name: "send_email"
33+
tool_name: "^send_email$"
34+
tool_type: "^function$"
3435
decision: "allow"
3536
allowed_param_patterns:
3637
"to[]": "^.+@berri\\.ai$"
@@ -44,7 +45,8 @@ guardrails:
4445
4546
```yaml
4647
- id: "unique_rule_id" # Unique identifier for the rule
47-
tool_name: "pattern" # Tool name or pattern to match
48+
tool_name: "^regex$" # Regex for tool name (optional, at least one of name/type required)
49+
tool_type: "^function$" # Regex for tool type (optional)
4850
decision: "allow" # "allow" or "deny"
4951
allowed_param_patterns: # Optional - regex map for argument paths (dot + [] notation)
5052
"path.to[].field": "^regex$"
@@ -103,6 +105,7 @@ litellm --config config.yaml --port 4000
103105
<Tabs>
104106
<TabItem value="block" label="Block Request">
105107

108+
106109
**Block request (`on_disallowed_action: block`)**
107110

108111
```bash

litellm/proxy/guardrails/guardrail_hooks/tool_permission.py

Lines changed: 95 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161

6262
self.rules: List[ToolPermissionRule] = []
6363
self._compiled_rule_patterns: Dict[str, Dict[str, re.Pattern]] = {}
64+
self._compiled_rule_targets: Dict[str, Dict[str, Optional[re.Pattern]]] = {}
6465
if rules:
6566
for rule_item in rules:
6667
if isinstance(rule_item, ToolPermissionRule):
@@ -69,6 +70,30 @@ def __init__(
6970
rule = ToolPermissionRule(**rule_item)
7071
self.rules.append(rule)
7172

73+
compiled_target_patterns: Dict[str, Optional[re.Pattern]] = {
74+
"tool_name": None,
75+
"tool_type": None,
76+
}
77+
if rule.tool_name is not None:
78+
try:
79+
compiled_target_patterns["tool_name"] = re.compile(
80+
rule.tool_name
81+
)
82+
except re.error as exc:
83+
raise ValueError(
84+
f"Invalid regex for tool_name in rule '{rule.id}': {exc}"
85+
) from exc
86+
if rule.tool_type is not None:
87+
try:
88+
compiled_target_patterns["tool_type"] = re.compile(
89+
rule.tool_type
90+
)
91+
except re.error as exc:
92+
raise ValueError(
93+
f"Invalid regex for tool_type in rule '{rule.id}': {exc}"
94+
) from exc
95+
self._compiled_rule_targets[rule.id] = compiled_target_patterns
96+
7297
if rule.allowed_param_patterns:
7398
compiled_patterns: Dict[str, re.Pattern] = {}
7499
for path, pattern in rule.allowed_param_patterns.items():
@@ -99,59 +124,75 @@ def get_config_model():
99124

100125
return ToolPermissionGuardrailConfigModel
101126

102-
def _matches_pattern(self, tool_name: str, pattern: str) -> bool:
103-
"""
104-
Check if a tool name matches a pattern
105-
106-
Supports patterns like:
107-
- "Bash" - exact match
108-
- "mcp__*" - prefix pattern (matches names starting wich "mcp__")
109-
- "*_read" - suffix wildcard (matches names ending with "_read")
110-
- "mcp__github_*_read" - infix wildcard (matches names like "mcp__github_mark_all_notifications_read")
111-
112-
Args:
113-
tool_name: Name of the tool to check
114-
pattern: Pattern to match against
115-
116-
Returns:
117-
True if the tool name matches the pattern
118-
"""
119-
# Handle exact matches
120-
if tool_name == pattern:
127+
def _matches_regex(
128+
self, pattern: Optional[re.Pattern], value: Optional[str]
129+
) -> bool:
130+
if pattern is None:
121131
return True
132+
if value is None:
133+
return False
134+
return bool(pattern.fullmatch(value))
122135

123-
if "*" in pattern:
124-
# Escape regex special chars except '*'
125-
escaped_pattern = re.escape(pattern)
126-
# Turn \* into .*
127-
regex_pattern = escaped_pattern.replace(r"\*", ".*")
128-
return bool(re.fullmatch(regex_pattern, tool_name))
136+
def _rule_matches_tool(
137+
self,
138+
rule: ToolPermissionRule,
139+
*,
140+
tool_name: Optional[str],
141+
tool_type: Optional[str] = None,
142+
) -> tuple[bool, bool]:
143+
target_patterns = self._compiled_rule_targets.get(rule.id, {})
144+
name_pattern = target_patterns.get("tool_name")
145+
type_pattern = target_patterns.get("tool_type")
146+
147+
name_required = rule.tool_name is not None
148+
type_required = rule.tool_type is not None
149+
150+
name_matched = (
151+
self._matches_regex(name_pattern, tool_name) if name_required else True
152+
)
153+
type_matched = (
154+
self._matches_regex(type_pattern, tool_type) if type_required else True
155+
)
156+
157+
overall_match = name_matched and type_matched
158+
should_check_params = name_required and name_matched
129159

130-
return False
160+
return overall_match, should_check_params
131161

132162
def _check_tool_permission(
133-
self, tool_name: str
163+
self,
164+
tool_name: Optional[str],
165+
tool_type: Optional[str] = None,
134166
) -> tuple[bool, Optional[str], Optional[str]]:
135167
"""
136168
Check if a tool is allowed based on the configured rules
137169
138170
Args:
139171
tool_name: Name of the tool to check
172+
tool_type: Type of the tool to check
140173
141174
Returns:
142175
Tuple of (is_allowed, rule_id, message)
143176
"""
144-
verbose_proxy_logger.debug(f"Checking permission for tool: {tool_name}")
177+
verbose_proxy_logger.debug(
178+
f"Checking permission for tool: {tool_name or tool_type}"
179+
)
145180

146181
# Check each rule in order
147182
for rule in self.rules:
148-
if self._matches_pattern(tool_name, rule.tool_name):
183+
matches, _ = self._rule_matches_tool(
184+
rule,
185+
tool_name=tool_name,
186+
tool_type=tool_type,
187+
)
188+
if matches:
149189
is_allowed = rule.decision == "allow"
150-
default_message = f"Tool '{tool_name}' {'allowed' if is_allowed else 'denied'} by rule '{rule.id}'"
190+
tool_identifier = tool_name or tool_type or "unknown_tool"
191+
default_message = f"Tool '{tool_identifier}' {'allowed' if is_allowed else 'denied'} by rule '{rule.id}'"
151192
message = self.render_violation_message(
152193
default=default_message,
153194
context={
154-
"tool_name": tool_name,
195+
"tool_name": tool_name or tool_identifier,
155196
"rule_id": rule.id,
156197
},
157198
)
@@ -160,11 +201,12 @@ def _check_tool_permission(
160201

161202
# No rule matched, use default action
162203
is_allowed = self.default_action == "allow"
163-
default_message = f"Tool '{tool_name}' {'allowed' if is_allowed else 'denied'} by default action"
204+
tool_identifier = tool_name or tool_type or "unknown_tool"
205+
default_message = f"Tool '{tool_identifier}' {'allowed' if is_allowed else 'denied'} by default action"
164206
message = self.render_violation_message(
165207
default=default_message,
166208
context={
167-
"tool_name": tool_name,
209+
"tool_name": tool_name or tool_identifier,
168210
"rule_id": None,
169211
},
170212
)
@@ -222,7 +264,7 @@ def _patterns_match_for_rule(
222264
*,
223265
arguments: Dict[str, Any],
224266
rule: ToolPermissionRule,
225-
tool_name: str,
267+
tool_name: Optional[str],
226268
) -> tuple[bool, Optional[str]]:
227269
compiled_patterns = self._compiled_rule_patterns.get(rule.id)
228270
if not compiled_patterns:
@@ -243,7 +285,7 @@ def _patterns_match_for_rule(
243285
return (
244286
False,
245287
f"Value '{raw_value}' for path '{path}' does not match allowed pattern"
246-
f" '{compiled_pattern.pattern}' for tool '{tool_name}'",
288+
f" '{compiled_pattern.pattern}' for tool '{tool_name or 'unknown_tool'}'",
247289
)
248290

249291
return True, None
@@ -252,19 +294,27 @@ def _get_permission_for_tool_call(
252294
self, tool_call: ChatCompletionMessageToolCall
253295
) -> tuple[bool, Optional[str], Optional[str]]:
254296
tool_name = tool_call.function.name if tool_call.function else None
255-
if not tool_name:
297+
tool_type = getattr(tool_call, "type", None)
298+
if not tool_name and not tool_type:
256299
return self.default_action == "allow", None, None
257300

301+
tool_identifier = tool_name or tool_type or "unknown_tool"
302+
258303
last_pattern_failure_msg: Optional[str] = None
259304

260305
for rule in self.rules:
261-
if not self._matches_pattern(tool_name, rule.tool_name):
306+
matches, should_check_params = self._rule_matches_tool(
307+
rule,
308+
tool_name=tool_name,
309+
tool_type=tool_type,
310+
)
311+
if not matches:
262312
continue
263313

264-
if rule.allowed_param_patterns:
314+
if rule.allowed_param_patterns and should_check_params:
265315
arguments = self._parse_tool_call_arguments(tool_call)
266316
if not arguments:
267-
last_pattern_failure_msg = f"Tool '{tool_name}' is missing arguments required by rule '{rule.id}'"
317+
last_pattern_failure_msg = f"Tool '{tool_identifier}' is missing arguments required by rule '{rule.id}'"
268318
continue
269319

270320
patterns_match, failure_message = self._patterns_match_for_rule(
@@ -277,22 +327,22 @@ def _get_permission_for_tool_call(
277327
continue
278328

279329
is_allowed = rule.decision == "allow"
280-
default_message = f"Tool '{tool_name}' {'allowed' if is_allowed else 'denied'} by rule '{rule.id}'"
330+
default_message = f"Tool '{tool_identifier}' {'allowed' if is_allowed else 'denied'} by rule '{rule.id}'"
281331
message = self.render_violation_message(
282332
default=default_message,
283-
context={"tool_name": tool_name, "rule_id": rule.id},
333+
context={"tool_name": tool_identifier, "rule_id": rule.id},
284334
)
285335
return is_allowed, rule.id, message
286336

287337
is_allowed = self.default_action == "allow"
288338
default_message = (
289339
last_pattern_failure_msg
290340
if (last_pattern_failure_msg and not is_allowed)
291-
else f"Tool '{tool_name}' {'allowed' if is_allowed else 'denied'} by default action"
341+
else f"Tool '{tool_identifier}' {'allowed' if is_allowed else 'denied'} by default action"
292342
)
293343
message = self.render_violation_message(
294344
default=default_message,
295-
context={"tool_name": tool_name, "rule_id": None},
345+
context={"tool_name": tool_identifier, "rule_id": None},
296346
)
297347
return is_allowed, None, message
298348

@@ -474,8 +524,9 @@ async def async_pre_call_hook(
474524
if tool["type"] != "function":
475525
continue
476526
tool_name: str = tool["function"]["name"]
527+
tool_type: Optional[str] = tool.get("type")
477528

478-
is_allowed, _, message = self._check_tool_permission(tool_name)
529+
is_allowed, _, message = self._check_tool_permission(tool_name, tool_type)
479530

480531
if not is_allowed and message is not None:
481532
verbose_proxy_logger.warning(f"Tool Permission Guardrail: {message}")

litellm/types/proxy/guardrails/guardrail_hooks/tool_permission.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Tool Permission Guardrail Type Definitions
22
from typing import Dict, List, Literal, Optional
33

4-
from pydantic import BaseModel, Field
4+
from pydantic import BaseModel, Field, field_validator, model_validator
55

66
from .base import GuardrailConfigModel
77

@@ -12,8 +12,13 @@ class ToolPermissionRule(BaseModel):
1212
"""
1313

1414
id: str = Field(description="Unique identifier for the rule")
15-
tool_name: str = Field(
16-
description="Tool name or pattern (e.g., 'Bash', 'mcp__github_*', 'mcp__github_*_read', '*_read')"
15+
tool_name: Optional[str] = Field(
16+
default=None,
17+
description="Regex pattern applied to the tool's function name",
18+
)
19+
tool_type: Optional[str] = Field(
20+
default=None,
21+
description="Regex pattern applied to the tool type (e.g., function)",
1722
)
1823
decision: Literal["allow", "deny"] = Field(
1924
description="Whether to allow or deny this tool usage"
@@ -23,6 +28,26 @@ class ToolPermissionRule(BaseModel):
2328
description="Optional regex map enforcing nested parameter values using dot/[] paths",
2429
)
2530

31+
@field_validator("tool_name", "tool_type", mode="before")
32+
@classmethod
33+
def _blank_to_none(cls, value: Optional[str]) -> Optional[str]:
34+
if value is None:
35+
return None
36+
if isinstance(value, str):
37+
stripped = value.strip()
38+
if not stripped:
39+
return None
40+
return stripped
41+
return value
42+
43+
@model_validator(mode="after")
44+
def _ensure_target_present(self):
45+
if self.tool_name is None and self.tool_type is None:
46+
raise ValueError(
47+
"Each rule must specify at least a tool_name or tool_type regex"
48+
)
49+
return self
50+
2651

2752
class ToolResult(BaseModel):
2853
"""
@@ -52,7 +77,7 @@ class ToolPermissionGuardrailConfigModel(GuardrailConfigModel):
5277

5378
rules: Optional[List[ToolPermissionRule]] = Field(
5479
default=None,
55-
description="Ordered allow/deny rules. Patterns support * wildcards and optional regex constraints on tool arguments.",
80+
description="Ordered allow/deny rules. Patterns use regex for tool names/types and optional regex constraints on tool arguments.",
5681
)
5782
default_action: Literal["allow", "deny"] = Field(
5883
default="deny", description="Fallback decision when no rule matches"

0 commit comments

Comments
 (0)