Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions crates/mcp/src/core/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ pub enum BuiltinToolType {
FileSearch,
/// Image generation tool (OpenAI: image_generation)
ImageGeneration,
/// Shell tool (OpenAI: shell)
Shell,
}

impl BuiltinToolType {
Expand All @@ -314,6 +316,7 @@ impl BuiltinToolType {
BuiltinToolType::CodeInterpreter => ResponseFormatConfig::CodeInterpreterCall,
BuiltinToolType::FileSearch => ResponseFormatConfig::FileSearchCall,
BuiltinToolType::ImageGeneration => ResponseFormatConfig::ImageGenerationCall,
BuiltinToolType::Shell => ResponseFormatConfig::ShellCall,
}
}
}
Expand All @@ -325,6 +328,7 @@ impl fmt::Display for BuiltinToolType {
BuiltinToolType::CodeInterpreter => write!(f, "code_interpreter"),
BuiltinToolType::FileSearch => write!(f, "file_search"),
BuiltinToolType::ImageGeneration => write!(f, "image_generation"),
BuiltinToolType::Shell => write!(f, "shell"),
}
}
}
Expand Down Expand Up @@ -355,6 +359,7 @@ pub enum ResponseFormatConfig {
CodeInterpreterCall,
FileSearchCall,
ImageGenerationCall,
ShellCall,
}

/// Argument mapping configuration for tool aliases.
Expand Down Expand Up @@ -1034,6 +1039,7 @@ tools:
ResponseFormatConfig::ImageGenerationCall,
"\"image_generation_call\"",
),
(ResponseFormatConfig::ShellCall, "\"shell_call\""),
];

for (format, expected) in formats {
Expand Down Expand Up @@ -1201,6 +1207,7 @@ policy:
(BuiltinToolType::CodeInterpreter, "\"code_interpreter\""),
(BuiltinToolType::FileSearch, "\"file_search\""),
(BuiltinToolType::ImageGeneration, "\"image_generation\""),
(BuiltinToolType::Shell, "\"shell\""),
];

for (builtin_type, expected) in types {
Expand Down Expand Up @@ -1230,6 +1237,10 @@ policy:
BuiltinToolType::ImageGeneration.response_format(),
ResponseFormatConfig::ImageGenerationCall
);
assert_eq!(
BuiltinToolType::Shell.response_format(),
ResponseFormatConfig::ShellCall
);
}

#[test]
Expand Down Expand Up @@ -1522,5 +1533,10 @@ servers:
"code_interpreter"
);
assert_eq!(BuiltinToolType::FileSearch.to_string(), "file_search");
assert_eq!(
BuiltinToolType::ImageGeneration.to_string(),
"image_generation"
);
assert_eq!(BuiltinToolType::Shell.to_string(), "shell");
}
}
1 change: 1 addition & 0 deletions crates/mcp/src/core/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ impl<'a> McpToolSession<'a> {
BuiltinToolType::WebSearchPreview,
BuiltinToolType::CodeInterpreter,
BuiltinToolType::FileSearch,
BuiltinToolType::Shell,
Comment thread
RohanSogani marked this conversation as resolved.
]
.into_iter()
.filter_map(|builtin_type| orchestrator.find_builtin_server(builtin_type))
Expand Down
117 changes: 115 additions & 2 deletions crates/mcp/src/transform/transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

use openai_protocol::responses::{
CodeInterpreterCallStatus, CodeInterpreterOutput, FileSearchCallStatus, FileSearchResult,
ImageGenerationCallStatus, ResponseOutputItem, WebSearchAction, WebSearchCallStatus,
WebSearchSource,
ImageGenerationCallStatus, ResponseOutputItem, ShellCallAction, ShellCallStatus,
WebSearchAction, WebSearchCallStatus, WebSearchSource,
};
use tracing::warn;

Expand Down Expand Up @@ -101,6 +101,7 @@ impl ResponseTransformer {
ResponseFormat::ImageGenerationCall => {
Self::to_image_generation_call(result, tool_call_id)
}
ResponseFormat::ShellCall => Self::to_shell_call(tool_call_id, arguments),
}
}

Expand Down Expand Up @@ -586,6 +587,54 @@ impl ResponseTransformer {
attributes: None,
})
}

/// Build the `shell_call` output item for a shell tool invocation.
///
/// `action` is derived from the raw `arguments` string, `call_id` echoes
/// `tool_call_id` unchanged, and `id` is the `sc_`-prefixed form of that same
/// identifier. The item is always emitted with `Completed` status and with
/// `environment` and `created_by` left unset.
fn to_shell_call(tool_call_id: &str, arguments: &str) -> ResponseOutputItem {
    // Parse first: this is the only step that may emit a warning.
    let action = parse_shell_call_action(arguments);
    let id = normalize_shell_call_id(tool_call_id);
    let call_id = String::from(tool_call_id);

    ResponseOutputItem::ShellCall {
        id,
        call_id,
        action,
        status: ShellCallStatus::Completed,
        environment: None,
        created_by: None,
    }
}
}

/// Parse the raw tool-call `arguments` string into a [`ShellCallAction`].
///
/// Only the `commands`, `max_output_length` and `timeout_ms` keys are read;
/// any other keys in the arguments object are ignored. A missing `commands`
/// key is treated as an empty array and the two other keys default to JSON
/// `null`.
///
/// This function never fails. If `arguments` is not valid JSON, is not a JSON
/// object, or the selected fields do not deserialize into a
/// `ShellCallAction`, a warning is logged and an empty action is returned.
fn parse_shell_call_action(arguments: &str) -> ShellCallAction {
    let mut fields = match serde_json::from_str::<serde_json::Value>(arguments) {
        Ok(serde_json::Value::Object(fields)) => fields,
        Ok(_) => {
            warn!("Expected shell_call arguments to be a JSON object; emitting empty action");
            return empty_shell_call_action();
        }
        Err(e) => {
            // Keep the parser error in the log so malformed arguments can be diagnosed.
            warn!(
                error = %e,
                "Failed to parse shell_call arguments as JSON; emitting empty action"
            );
            return empty_shell_call_action();
        }
    };

    // The parsed map is owned here, so the fields of interest are moved out of
    // it instead of being cloned.
    let commands = fields
        .remove("commands")
        .unwrap_or_else(|| serde_json::json!([]));
    let max_output_length = fields
        .remove("max_output_length")
        .unwrap_or(serde_json::Value::Null);
    let timeout_ms = fields
        .remove("timeout_ms")
        .unwrap_or(serde_json::Value::Null);

    let action = serde_json::json!({
        "commands": commands,
        "max_output_length": max_output_length,
        "timeout_ms": timeout_ms,
    });

    serde_json::from_value::<ShellCallAction>(action).unwrap_or_else(|e| {
        warn!(
            error = %e,
            "Failed to parse shell_call action fields; emitting empty action"
        );
        empty_shell_call_action()
    })
}

/// Fallback action used when shell_call arguments cannot be interpreted:
/// an empty command list with both optional fields unset.
fn empty_shell_call_action() -> ShellCallAction {
    let commands = Vec::new();

    ShellCallAction {
        commands,
        timeout_ms: None,
        max_output_length: None,
    }
}

/// Strip the base64 `result` payload from an `ImageGenerationCall` output
Expand Down Expand Up @@ -633,6 +682,18 @@ fn parse_text_block_payload(item: &serde_json::Value) -> Option<serde_json::Valu
}
}

/// Derive an `sc_`-prefixed item id from a tool call id.
///
/// An id that already starts with `sc_` is returned unchanged. Otherwise a
/// single leading `fc_` or `call_` prefix (checked in that order) is removed
/// and `sc_` is prepended to what remains; ids with neither prefix simply get
/// `sc_` prepended.
fn normalize_shell_call_id(source_id: &str) -> String {
    if source_id.starts_with("sc_") {
        return source_id.to_owned();
    }

    // Only the first matching prefix is stripped, and only once.
    let remainder = ["fc_", "call_"]
        .iter()
        .find_map(|prefix| source_id.strip_prefix(*prefix))
        .unwrap_or(source_id);

    format!("sc_{remainder}")
}

#[cfg(test)]
mod tests {
use serde_json::json;
Expand Down Expand Up @@ -1022,6 +1083,58 @@ mod tests {
}
}

/// A shell tool call is transformed into a completed `ShellCall` item whose
/// id is `sc_`-prefixed and whose action mirrors the call arguments.
#[test]
fn test_shell_call_transform() {
    let raw_arguments = r#"{"commands":["echo hello"],"timeout_ms":1000}"#;

    let output = ResponseTransformer::transform(
        &json!({}),
        &ResponseFormat::ShellCall,
        "call-shell-1",
        "server",
        "shell",
        raw_arguments,
    );

    let ResponseOutputItem::ShellCall {
        id,
        call_id,
        action,
        environment,
        status,
        ..
    } = output
    else {
        panic!("Expected ShellCall");
    };

    assert_eq!(id, "sc_call-shell-1");
    assert_eq!(call_id, "call-shell-1");
    assert_eq!(action.commands, vec!["echo hello"]);
    assert_eq!(action.timeout_ms, Some(1000));
    assert!(environment.is_none());
    assert_eq!(status, ShellCallStatus::Completed);
}

/// Extra keys in the call arguments (here `user`) must not prevent the
/// action fields from being parsed.
#[test]
fn test_shell_call_transform_preserves_action_with_dispatch_metadata() {
    let raw_arguments = r#"{"commands":["pwd"],"timeout_ms":500,"user":"request-user"}"#;

    let output = ResponseTransformer::transform(
        &json!({}),
        &ResponseFormat::ShellCall,
        "call-shell-2",
        "server",
        "shell",
        raw_arguments,
    );

    let ResponseOutputItem::ShellCall { action, .. } = output else {
        panic!("Expected ShellCall");
    };

    assert_eq!(action.commands, vec!["pwd"]);
    assert_eq!(action.timeout_ms, Some(500));
    assert_eq!(action.max_output_length, None);
}

#[test]
fn test_file_search_transform() {
let result = json!({
Expand Down
6 changes: 6 additions & 0 deletions crates/mcp/src/transform/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub enum ResponseFormat {
FileSearchCall,
/// Transform to OpenAI image_generation_call format
ImageGenerationCall,
/// Transform to OpenAI shell_call format
ShellCall,
}

impl ResponseFormat {
Expand All @@ -36,6 +38,7 @@ impl ResponseFormat {
ResponseFormat::CodeInterpreterCall => Some(BuiltinToolType::CodeInterpreter),
ResponseFormat::FileSearchCall => Some(BuiltinToolType::FileSearch),
ResponseFormat::ImageGenerationCall => Some(BuiltinToolType::ImageGeneration),
ResponseFormat::ShellCall => Some(BuiltinToolType::Shell),
}
}
}
Expand All @@ -48,6 +51,7 @@ impl From<ResponseFormatConfig> for ResponseFormat {
ResponseFormatConfig::CodeInterpreterCall => ResponseFormat::CodeInterpreterCall,
ResponseFormatConfig::FileSearchCall => ResponseFormat::FileSearchCall,
ResponseFormatConfig::ImageGenerationCall => ResponseFormat::ImageGenerationCall,
ResponseFormatConfig::ShellCall => ResponseFormat::ShellCall,
}
}
}
Expand All @@ -70,6 +74,7 @@ mod tests {
ResponseFormat::ImageGenerationCall,
"\"image_generation_call\"",
),
(ResponseFormat::ShellCall, "\"shell_call\""),
];

for (format, expected) in formats {
Expand All @@ -95,6 +100,7 @@ mod tests {
BuiltinToolType::CodeInterpreter,
BuiltinToolType::FileSearch,
BuiltinToolType::ImageGeneration,
BuiltinToolType::Shell,
];
for kind in kinds {
let fmt: ResponseFormat = kind.response_format().into();
Expand Down
49 changes: 48 additions & 1 deletion model_gateway/src/routers/common/mcp_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ pub fn collect_builtin_routing(
ResponseTool::WebSearchPreview(_) => BuiltinToolType::WebSearchPreview,
ResponseTool::CodeInterpreter(_) => BuiltinToolType::CodeInterpreter,
ResponseTool::ImageGeneration(_) => BuiltinToolType::ImageGeneration,
ResponseTool::Shell(_) => BuiltinToolType::Shell,
Comment thread
RohanSogani marked this conversation as resolved.
_ => continue,
};

Expand Down Expand Up @@ -229,6 +230,7 @@ pub fn extract_builtin_types(tools: &[ResponseTool]) -> Vec<BuiltinToolType> {
ResponseTool::WebSearchPreview(_) => Some(BuiltinToolType::WebSearchPreview),
ResponseTool::CodeInterpreter(_) => Some(BuiltinToolType::CodeInterpreter),
ResponseTool::ImageGeneration(_) => Some(BuiltinToolType::ImageGeneration),
ResponseTool::Shell(_) => Some(BuiltinToolType::Shell),
Comment thread
RohanSogani marked this conversation as resolved.
_ => None,
})
.collect()
Expand Down Expand Up @@ -414,7 +416,7 @@ mod tests {
common::Function,
responses::{
CodeInterpreterTool, FunctionTool, ImageGenerationTool, McpTool, ResponseTool,
WebSearchPreviewTool,
ShellTool, WebSearchPreviewTool,
},
};
use serde_json::json;
Expand Down Expand Up @@ -699,6 +701,51 @@ mod tests {
);
}

/// A request carrying a `shell` tool is routed to the server configured with
/// `builtin_type: Shell`, using that server's builtin tool name and the
/// `ShellCall` response format.
#[tokio::test]
async fn test_collect_builtin_routing_shell() {
    // Single tool entry exposed by the shell server.
    let shell_tool = ToolConfig {
        response_format: ResponseFormatConfig::ShellCall,
        ..Default::default()
    };
    let shell_tools = HashMap::from([("execute_shell_commands".to_string(), shell_tool)]);

    let shell_server = McpServerConfig {
        name: "shell-server".to_string(),
        transport: McpTransport::Streamable {
            url: "http://localhost:9996/shell".to_string(),
            token: None,
            headers: HashMap::new(),
        },
        proxy: None,
        required: false,
        tools: Some(shell_tools),
        builtin_type: Some(BuiltinToolType::Shell),
        builtin_tool_name: Some("execute_shell_commands".to_string()),
        internal: false,
    };

    let config = McpConfig {
        servers: vec![shell_server],
        pool: Default::default(),
        proxy: None,
        warmup: Vec::new(),
        inventory: Default::default(),
        policy: Default::default(),
    };

    let orchestrator = Arc::new(McpOrchestrator::new(config).await.unwrap());
    let tools = vec![ResponseTool::Shell(ShellTool::default())];

    let routing = collect_builtin_routing(&orchestrator, Some(&tools));

    // Exactly one routing entry is expected; inspect it once its presence is confirmed.
    assert_eq!(routing.len(), 1);
    let entry = &routing[0];
    assert_eq!(entry.builtin_type, BuiltinToolType::Shell);
    assert_eq!(entry.server_name, "shell-server");
    assert_eq!(entry.tool_name, "execute_shell_commands");
    assert_eq!(entry.response_format, ResponseFormat::ShellCall);
}

// =========================================================================
// ensure_request_mcp_client tests
// =========================================================================
Expand Down
Loading
Loading