Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/mcp/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ pub use orchestrator::{
};
pub use pool::{McpConnectionPool, PoolKey};
pub use reconnect::ReconnectionManager;
pub use session::{McpServerBinding, McpToolSession, DEFAULT_SERVER_LABEL};
pub use session::{McpServerBinding, McpToolExposureFilter, McpToolSession, DEFAULT_SERVER_LABEL};
244 changes: 222 additions & 22 deletions crates/mcp/src/core/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,62 @@ pub struct McpServerBinding {
pub label: String,
/// Internal key used to look up the server in the orchestrator.
pub server_key: String,
/// Optional per-server tool allowlist.
/// Optional per-server tool exposure filter.
///
/// When `Some`, only the listed tool names are exposed for this server.
/// When `Some`, only tools matching the filter are exposed for this server.
/// When `None`, all tools from the server are exposed.
pub allowed_tools: Option<Vec<String>>,
pub allowed_tools: Option<McpToolExposureFilter>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpToolExposureFilter {
pub tool_names: Option<Vec<String>>,
pub read_only: Option<bool>,
}

impl McpToolExposureFilter {
pub fn tool_names(tool_names: Vec<String>) -> Self {
Self {
tool_names: Some(tool_names),
read_only: None,
}
}

pub fn read_only(read_only: bool) -> Self {
Self {
tool_names: None,
read_only: Some(read_only),
}
}

pub fn new(tool_names: Option<Vec<String>>, read_only: Option<bool>) -> Self {
Self {
tool_names,
read_only,
}
}
}

struct PreparedMcpToolExposureFilter<'a> {
tool_names: Option<HashSet<&'a str>>,
read_only: Option<bool>,
}

impl<'a> PreparedMcpToolExposureFilter<'a> {
fn from_filter(filter: &'a McpToolExposureFilter) -> Self {
let tool_names = filter.tool_names.as_ref().map(|names| {
names
.iter()
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect()
});

Self {
tool_names,
read_only: filter.read_only,
}
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -120,26 +171,25 @@ impl<'a> McpToolSession<'a> {
let server_keys: Vec<String> = mcp_servers.iter().map(|b| b.server_key.clone()).collect();
let mut mcp_tools = Self::collect_visible_mcp_tools(orchestrator, &server_keys);

// Build per-server allowlists from bindings that specify allowed_tools.
let allowed_tools_by_server_key: HashMap<&str, HashSet<&str>> = mcp_servers
.iter()
.filter_map(|b| {
b.allowed_tools.as_ref().map(|tools| {
let set: HashSet<&str> = tools
.iter()
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect();
(b.server_key.as_str(), set)
// Build per-server exposure filters from bindings that specify allowed_tools.
let allowed_tools_by_server_key: HashMap<&str, PreparedMcpToolExposureFilter<'_>> =
mcp_servers
.iter()
.filter_map(|b| {
b.allowed_tools.as_ref().map(|filter| {
(
b.server_key.as_str(),
PreparedMcpToolExposureFilter::from_filter(filter),
)
})
})
})
.collect();
.collect();

if !allowed_tools_by_server_key.is_empty() {
mcp_tools.retain(|entry| {
match allowed_tools_by_server_key.get(Self::associated_server_key(entry)) {
None => true,
Some(allowed) => Self::matches_allowed_tool_name(entry, allowed),
Some(filter) => Self::matches_allowed_tool_filter(entry, filter),
}
});
}
Expand Down Expand Up @@ -830,6 +880,9 @@ impl<'a> McpToolSession<'a> {
for direct_entry in direct_tools {
if let Some(mut alias_entries) = aliases_by_target.remove(&direct_entry.qualified_name)
{
for alias_entry in &mut alias_entries {
alias_entry.annotations = direct_entry.annotations.clone();
}
visible_tools.append(&mut alias_entries);
} else {
visible_tools.push(direct_entry);
Expand All @@ -851,7 +904,20 @@ impl<'a> McpToolSession<'a> {
.unwrap_or_else(|| entry.server_key())
}

fn matches_allowed_tool_name(entry: &ToolEntry, allowed: &HashSet<&str>) -> bool {
fn matches_allowed_tool_filter(
entry: &ToolEntry,
filter: &PreparedMcpToolExposureFilter<'_>,
) -> bool {
if let Some(read_only) = filter.read_only {
if entry.annotations.read_only != read_only {
return false;
Comment thread
RohanSogani marked this conversation as resolved.
}
}

let Some(allowed) = filter.tool_names.as_ref() else {
return true;
};

allowed.contains(entry.tool_name())
|| entry
.alias_target
Expand Down Expand Up @@ -917,7 +983,7 @@ mod tests {
use serde_json::json;

use super::*;
use crate::core::config::Tool as McpTool;
use crate::{annotations::ToolAnnotations, core::config::Tool as McpTool};

#[test]
fn test_session_creation_keeps_servers() {
Expand Down Expand Up @@ -1338,7 +1404,9 @@ mod tests {
vec![McpServerBinding {
label: "mock".to_string(),
server_key: "server1".to_string(),
allowed_tools: Some(vec!["brave_web_search".to_string()]),
allowed_tools: Some(McpToolExposureFilter::tool_names(vec![
"brave_web_search".to_string()
])),
}],
"test-request",
);
Expand All @@ -1352,6 +1420,87 @@ mod tests {
assert_eq!(listed[0].tool_name(), "brave_web_search");
}

#[test]
fn test_allowed_tools_read_only_filter_uses_tool_annotations() {
let orchestrator = McpOrchestrator::new_test();

orchestrator.tool_inventory().insert_entry(
ToolEntry::from_server_tool("server1", create_test_tool("read_tool"))
.with_annotations(ToolAnnotations::new().with_read_only(true)),
);
orchestrator.tool_inventory().insert_entry(
ToolEntry::from_server_tool("server1", create_test_tool("write_tool"))
.with_annotations(ToolAnnotations::new().with_read_only(false)),
);
orchestrator
.tool_inventory()
.insert_entry(ToolEntry::from_server_tool(
"server1",
create_test_tool("missing_hint_tool"),
));

let session = McpToolSession::new(
&orchestrator,
vec![McpServerBinding {
label: "mock".to_string(),
server_key: "server1".to_string(),
allowed_tools: Some(McpToolExposureFilter::read_only(true)),
}],
"test-request",
);

let names: HashSet<&str> = session
.mcp_tools()
.iter()
.map(|entry| entry.tool_name())
.collect();
assert_eq!(names, HashSet::from(["read_tool"]));
assert!(session.has_exposed_tool("read_tool"));
assert!(!session.has_exposed_tool("write_tool"));
assert!(!session.has_exposed_tool("missing_hint_tool"));
}

#[test]
fn test_allowed_tools_read_only_filter_combines_with_tool_names() {
let orchestrator = McpOrchestrator::new_test();

orchestrator.tool_inventory().insert_entry(
ToolEntry::from_server_tool("server1", create_test_tool("allowed_read_tool"))
.with_annotations(ToolAnnotations::new().with_read_only(true)),
);
orchestrator.tool_inventory().insert_entry(
ToolEntry::from_server_tool("server1", create_test_tool("other_read_tool"))
.with_annotations(ToolAnnotations::new().with_read_only(true)),
);
orchestrator.tool_inventory().insert_entry(
ToolEntry::from_server_tool("server1", create_test_tool("allowed_write_tool"))
.with_annotations(ToolAnnotations::new().with_read_only(false)),
);

let session = McpToolSession::new(
&orchestrator,
vec![McpServerBinding {
label: "mock".to_string(),
server_key: "server1".to_string(),
allowed_tools: Some(McpToolExposureFilter::new(
Some(vec![
"allowed_read_tool".to_string(),
"allowed_write_tool".to_string(),
]),
Some(true),
)),
}],
"test-request",
);

let names: HashSet<&str> = session
.mcp_tools()
.iter()
.map(|entry| entry.tool_name())
.collect();
assert_eq!(names, HashSet::from(["allowed_read_tool"]));
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

#[test]
fn test_alias_tools_replace_target_tool_in_session_inventory() {
let orchestrator = McpOrchestrator::new_test();
Expand Down Expand Up @@ -1427,7 +1576,9 @@ mod tests {
vec![McpServerBinding {
label: "brave".to_string(),
server_key: "server1".to_string(),
allowed_tools: Some(vec!["web_search".to_string()]),
allowed_tools: Some(McpToolExposureFilter::tool_names(vec![
"web_search".to_string()
])),
}],
"test-request",
);
Expand All @@ -1437,6 +1588,53 @@ mod tests {
assert_eq!(session.mcp_tools()[0].tool_name(), "web_search");
}

#[test]
fn test_allowed_tools_read_only_filter_uses_alias_target_annotations() {
let orchestrator = McpOrchestrator::new_test();

orchestrator.tool_inventory().insert_entry(
ToolEntry::from_server_tool("server1", create_test_tool("brave_web_search"))
.with_annotations(ToolAnnotations::new().with_read_only(true)),
);

orchestrator
.register_alias(
"web_search",
"server1",
"brave_web_search",
None,
ResponseFormat::WebSearchCall,
)
.expect("alias registration should succeed");

let read_only_session = McpToolSession::new(
&orchestrator,
vec![McpServerBinding {
label: "brave".to_string(),
server_key: "server1".to_string(),
allowed_tools: Some(McpToolExposureFilter::read_only(true)),
}],
"test-request",
);

assert!(read_only_session.has_exposed_tool("web_search"));
assert_eq!(read_only_session.mcp_tools().len(), 1);
assert_eq!(read_only_session.mcp_tools()[0].tool_name(), "web_search");

let write_session = McpToolSession::new(
&orchestrator,
vec![McpServerBinding {
label: "brave".to_string(),
server_key: "server1".to_string(),
allowed_tools: Some(McpToolExposureFilter::read_only(false)),
}],
"test-request",
);

assert!(!write_session.has_exposed_tool("web_search"));
assert!(write_session.mcp_tools().is_empty());
}

#[test]
fn test_is_internal_tool_for_internal_server() {
use crate::core::config::{McpConfig, McpServerConfig, McpTransport};
Expand Down Expand Up @@ -1752,7 +1950,9 @@ mod tests {
McpServerBinding {
label: "brave".to_string(),
server_key: "server1".to_string(),
allowed_tools: Some(vec!["brave_web_search".to_string()]),
allowed_tools: Some(McpToolExposureFilter::tool_names(vec![
"brave_web_search".to_string(),
])),
},
McpServerBinding {
label: "deepwiki".to_string(),
Expand Down
9 changes: 5 additions & 4 deletions crates/mcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ pub use core::{config, pool as connection_pool};
pub use core::{
ArgMappingConfig, BuiltinToolType, ConfigValidationError, HandlerRequestContext,
LatencySnapshot, McpConfig, McpMetrics, McpOrchestrator, McpRequestContext, McpServerBinding,
McpServerConfig, McpToolSession, McpTransport, MetricsSnapshot, PendingToolExecution,
PolicyConfig, PolicyDecisionConfig, PoolKey, RefreshRequest, ResponseFormatConfig,
ServerPolicyConfig, SmgClientHandler, Tool, ToolCallResult, ToolConfig, ToolExecutionInput,
ToolExecutionOutput, ToolExecutionResult, TrustLevelConfig, DEFAULT_SERVER_LABEL,
McpServerConfig, McpToolExposureFilter, McpToolSession, McpTransport, MetricsSnapshot,
PendingToolExecution, PolicyConfig, PolicyDecisionConfig, PoolKey, RefreshRequest,
ResponseFormatConfig, ServerPolicyConfig, SmgClientHandler, Tool, ToolCallResult, ToolConfig,
ToolExecutionInput, ToolExecutionOutput, ToolExecutionResult, TrustLevelConfig,
DEFAULT_SERVER_LABEL,
};

// Re-export shared types
Expand Down
4 changes: 3 additions & 1 deletion model_gateway/src/routers/anthropic/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ impl RouterTrait for AnthropicRouter {
url: Some(server.url.clone()),
authorization: server.authorization_token.clone(),
headers: HashMap::new(),
allowed_tools: toolset_allowed.get(&server.name).and_then(|v| v.clone()),
allowed_tools: toolset_allowed
.get(&server.name)
.and_then(|v| v.clone().map(smg_mcp::McpToolExposureFilter::tool_names)),
})
.collect();

Expand Down
Loading
Loading