diff --git a/crates/mcp/Cargo.toml b/crates/mcp/Cargo.toml index 0f83a3ff6..78bc314d8 100644 --- a/crates/mcp/Cargo.toml +++ b/crates/mcp/Cargo.toml @@ -17,7 +17,6 @@ categories = ["api-bindings"] name = "smg_mcp" [dependencies] -openai-protocol.workspace = true async-trait.workspace = true axum.workspace = true backoff = { version = "0.4", features = ["tokio", "futures"] } diff --git a/crates/mcp/src/core/config.rs b/crates/mcp/src/core/config.rs index 9de390098..5de5c8896 100644 --- a/crates/mcp/src/core/config.rs +++ b/crates/mcp/src/core/config.rs @@ -336,9 +336,16 @@ pub struct ToolConfig { #[serde(skip_serializing_if = "Option::is_none")] pub alias: Option, - /// Response format for transformation (default: passthrough) - #[serde(default)] - pub response_format: ResponseFormatConfig, + /// Response format for transformation. + /// + /// `None` means "use whatever the surrounding context decides" — for a + /// builtin-routed tool that becomes the builtin's hosted format, otherwise + /// it falls through to `Passthrough`. `Some(Passthrough)` is an *explicit* + /// passthrough request and overrides the builtin default. This distinction + /// lets users add an `alias` or `arg_mapping` to a builtin tool without + /// silently disabling its hosted-format wire shape. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub response_format: Option, /// Argument mapping configuration #[serde(default, skip_serializing_if = "Option::is_none")] @@ -346,7 +353,7 @@ pub struct ToolConfig { } /// Response format configuration (mirrors ResponseFormat but for config). -#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Default)] #[serde(rename_all = "snake_case")] pub enum ResponseFormatConfig { #[default] @@ -929,7 +936,7 @@ tools: assert_eq!(tool_config.alias, Some("web_search".to_string())); assert_eq!( tool_config.response_format, - ResponseFormatConfig::WebSearchCall + Some(ResponseFormatConfig::WebSearchCall) ); let arg_mapping = tool_config.arg_mapping.as_ref().unwrap(); @@ -994,7 +1001,7 @@ tools: assert!(tool_config.alias.is_none()); assert_eq!( tool_config.response_format, - ResponseFormatConfig::FileSearchCall + Some(ResponseFormatConfig::FileSearchCall) ); assert!(tool_config.arg_mapping.is_none()); } @@ -1013,10 +1020,10 @@ tools: let tools = config.tools.as_ref().unwrap(); let tool_config = tools.get("my_tool").unwrap(); assert!(tool_config.alias.is_none()); - assert_eq!( - tool_config.response_format, - ResponseFormatConfig::Passthrough - ); + // `my_tool: {}` carries no `response_format` field, so the deserialized + // value is `None` (meaning "inherit context"), not an explicit + // `Passthrough`. + assert_eq!(tool_config.response_format, None); assert!(tool_config.arg_mapping.is_none()); } @@ -1067,15 +1074,23 @@ tools: let tool_a = tools.get("tool_a").unwrap(); assert_eq!(tool_a.alias, Some("a".to_string())); - assert_eq!(tool_a.response_format, ResponseFormatConfig::WebSearchCall); + assert_eq!( + tool_a.response_format, + Some(ResponseFormatConfig::WebSearchCall) + ); let tool_b = tools.get("tool_b").unwrap(); assert!(tool_b.alias.is_none()); - assert_eq!(tool_b.response_format, ResponseFormatConfig::FileSearchCall); + assert_eq!( + tool_b.response_format, + Some(ResponseFormatConfig::FileSearchCall) + ); let tool_c = tools.get("tool_c").unwrap(); assert_eq!(tool_c.alias, Some("c".to_string())); - assert_eq!(tool_c.response_format, ResponseFormatConfig::Passthrough); + // Alias-only stanza: no `response_format` field → `None` (meaning + // "inherit context"), not explicit Passthrough. + assert_eq!(tool_c.response_format, None); } #[test] diff --git a/crates/mcp/src/core/mod.rs b/crates/mcp/src/core/mod.rs index d323e4b90..6a6cdc25e 100644 --- a/crates/mcp/src/core/mod.rs +++ b/crates/mcp/src/core/mod.rs @@ -22,7 +22,7 @@ pub use config::{ pub use handler::{HandlerRequestContext, RefreshRequest, SmgClientHandler}; pub use metrics::{LatencySnapshot, McpMetrics, MetricsSnapshot}; pub use orchestrator::{ - McpOrchestrator, McpRequestContext, PendingToolExecution, ToolCallResult, ToolExecutionInput, + McpOrchestrator, McpRequestContext, PendingToolExecution, ToolExecutionInput, ToolExecutionOutput, ToolExecutionResult, }; pub use pool::{McpConnectionPool, PoolKey}; diff --git a/crates/mcp/src/core/orchestrator.rs b/crates/mcp/src/core/orchestrator.rs index d60c6d79f..e0197cdcc 100644 --- a/crates/mcp/src/core/orchestrator.rs +++ b/crates/mcp/src/core/orchestrator.rs @@ -37,7 +37,6 @@ use std::{ }; use dashmap::DashMap; -use openai_protocol::responses::ResponseOutputItem; use rmcp::{ model::{CallToolRequestParam, CallToolResult}, service::{RunningService, ServiceError}, @@ -63,9 +62,9 @@ use crate::{ error::{McpError, McpResult}, inventory::{ AliasTarget, ArgMapping, QualifiedToolName, ToolCategory, ToolEntry, ToolInventory, + ALIAS_SERVER_KEY, }, tenant::TenantContext, - transform::{ResponseFormat, ResponseTransformer}, }; /// Build request headers from token and custom headers. @@ -130,20 +129,7 @@ struct ServerEntry { config: McpServerConfig, } -/// Result of a tool call. -#[derive(Debug)] -pub enum ToolCallResult { - /// Successfully executed and transformed. - Success(ResponseOutputItem), - /// Pending approval from user. - PendingApproval(McpApprovalRequest), -} - /// Internal result type for approval-checked execution. -/// -/// Used to avoid code duplication between `execute_tool_with_approval` and -/// `execute_tool_with_approval_raw`. Contains either the raw result or -/// the pending approval request. enum ApprovalExecutionResult { /// Tool executed successfully, contains raw result. Success(CallToolResult), @@ -151,51 +137,6 @@ enum ApprovalExecutionResult { PendingApproval(McpApprovalRequest), } -impl ToolCallResult { - /// Get the transformed output item directly without serialization. - /// - /// Returns `Some(item)` for successful tool calls, `None` for pending approvals. - /// This avoids the serialize/deserialize roundtrip when the caller needs - /// the item as a Value or ResponseOutputItem. - pub fn into_item(self) -> Option { - match self { - ToolCallResult::Success(item) => Some(item), - ToolCallResult::PendingApproval(_) => None, - } - } - - /// Convert the result to a serialized output tuple. - /// - /// Returns `(output_str, is_error, error_message)` suitable for - /// recording in conversation history or emitting as events. - /// - /// This centralizes the result serialization logic that was previously - /// duplicated across all routers. - pub fn into_serialized(self) -> (String, bool, Option) { - match self { - ToolCallResult::Success(item) => match serde_json::to_string(&item) { - Ok(s) => (s, false, None), - Err(e) => { - let err = format!("Failed to serialize tool result: {e}"); - ( - serde_json::json!({"error": &err}).to_string(), - true, - Some(err), - ) - } - }, - ToolCallResult::PendingApproval(_) => { - let err = "Tool requires approval (not supported in this context)".to_string(); - ( - serde_json::json!({"error": &err}).to_string(), - true, - Some(err), - ) - } - } - } -} - // ============================================================================ // Batch Tool Execution Types // ============================================================================ @@ -216,12 +157,10 @@ pub struct ToolExecutionInput { /// Output from batch tool execution. /// -/// Contains all information needed by routers to build responses, -/// record state, and emit events. The MCP crate handles: -/// - Tool lookup and execution -/// - Result serialization and transformation -/// - Error handling +/// `#[non_exhaustive]` so additive fields don't break consumers; only +/// `smg-mcp` constructs this. #[derive(Debug, Clone)] +#[non_exhaustive] pub struct ToolExecutionOutput { /// The call_id from the input (for matching). pub call_id: String, @@ -241,8 +180,6 @@ pub struct ToolExecutionOutput { pub is_error: bool, /// Error message if `is_error` is true. pub error_message: Option, - /// Response format for transforming output to API-specific types. - pub response_format: ResponseFormat, /// Execution duration. pub duration: Duration, } @@ -256,6 +193,7 @@ pub enum ToolExecutionResult { /// Pending approval from resolved tool execution. #[derive(Debug, Clone)] +#[non_exhaustive] pub struct PendingToolExecution { pub call_id: String, pub tool_name: String, @@ -263,7 +201,6 @@ pub struct PendingToolExecution { pub server_label: String, pub arguments_str: String, pub approval_request: McpApprovalRequest, - pub response_format: ResponseFormat, pub duration: Duration, } @@ -286,33 +223,12 @@ impl ToolExecutionResult { error_message: Some( "Tool requires approval (not supported in this context)".to_string(), ), - response_format: pending.response_format, duration: pending.duration, }, } } } -impl ToolExecutionOutput { - /// Get the transformed ResponseOutputItem. - /// - /// Transforms the raw output to the appropriate ResponseOutputItem type - /// based on the tool's configured response format (WebSearchCall, - /// CodeInterpreterCall, FileSearchCall, or Passthrough/McpCall). - /// - /// Uses `server_label` (user-facing) for the output, not `server_key` (internal). - pub fn to_response_item(&self) -> ResponseOutputItem { - ResponseTransformer::transform( - &self.output, - &self.response_format, - &self.call_id, - &self.server_label, - &self.tool_name, - &self.arguments_str, - ) - } -} - /// Main orchestrator for MCP operations. /// /// Thread-safe and designed for sharing across async tasks. @@ -484,12 +400,8 @@ impl McpOrchestrator { // Load tools from server self.load_server_inventory(&config.name, &client).await; - // Apply tool configs (aliases, response formats) self.apply_tool_configs(config); - // Apply builtin response format if server has builtin_type configured - self.apply_builtin_response_format(config); - // Store server entry with config for builtin lookups self.static_servers.insert( config.name.clone(), @@ -505,7 +417,7 @@ impl McpOrchestrator { Ok(()) } - /// Apply tool configurations from server config (aliases, response formats, arg mappings). + /// Apply tool configurations from server config (aliases and arg mappings). fn apply_tool_configs(&self, config: &McpServerConfig) { let Some(tools) = &config.tools else { return; @@ -524,8 +436,6 @@ impl McpOrchestrator { continue; } - // Get the existing entry to update or create alias - let response_format: ResponseFormat = tool_config.response_format.clone().into(); let arg_mapping = tool_config.arg_mapping.as_ref().map(|cfg| { let mut mapping = ArgMapping::new(); for (from, to) in &cfg.renames { @@ -542,92 +452,29 @@ impl McpOrchestrator { // If there's an alias, register it if let Some(alias_name) = &tool_config.alias { - if let Err(e) = self.register_alias( - alias_name, - &config.name, - tool_name, - arg_mapping, - response_format.clone(), - ) { + if let Err(e) = + self.register_alias(alias_name, &config.name, tool_name, arg_mapping) + { warn!( "Failed to register alias '{}' for '{}:{}': {}", alias_name, config.name, tool_name, e ); } else { info!( - "Registered alias '{}' → '{}:{}' with format {:?}", - alias_name, config.name, tool_name, response_format + "Registered alias '{}' → '{}:{}'", + alias_name, config.name, tool_name ); } - } else if response_format != ResponseFormat::Passthrough { - // No alias, but has custom response format - update the entry directly + } else if let Some(mapping) = arg_mapping { + // No alias but has arg_mapping - update entry directly if let Some(mut entry) = self.tool_inventory.get_entry(&config.name, tool_name) { - entry.response_format = response_format.clone(); - entry.arg_mapping.clone_from(&arg_mapping); + entry.arg_mapping = Some(mapping); self.tool_inventory.insert_entry(entry); - info!( - "Set response format {:?} for '{}:{}'", - response_format, config.name, tool_name - ); } } } } - /// Apply builtin response format to the builtin_tool_name if not explicitly overridden. - /// - /// When a server is configured with `builtin_type` and `builtin_tool_name`, the - /// corresponding tool should use the response format associated with the builtin type - /// (e.g., WebSearchPreview -> WebSearchCall) unless explicitly overridden in the tools config. - fn apply_builtin_response_format(&self, config: &McpServerConfig) { - let Some(builtin_type) = &config.builtin_type else { - return; - }; - let Some(tool_name) = &config.builtin_tool_name else { - return; - }; - - let has_explicit_config = config - .tools - .as_ref() - .is_some_and(|tools| tools.contains_key(tool_name)); - - if has_explicit_config { - debug!( - server = %config.name, - tool = %tool_name, - "Builtin tool has explicit config, skipping auto-apply of response_format" - ); - return; - } - - let response_format: ResponseFormat = builtin_type.response_format().into(); - - let updated = self - .tool_inventory - .update_entry(&config.name, tool_name, |entry| { - if entry.response_format != response_format { - info!( - server = %config.name, - tool = %tool_name, - builtin_type = %builtin_type, - format = ?response_format, - "Applied builtin response format" - ); - entry.response_format = response_format.clone(); - } - }); - - if !updated { - warn!( - server = %config.name, - tool = %tool_name, - builtin_type = %builtin_type, - "Builtin tool not found on server" - ); - } - } - /// Internal server connection logic. async fn connect_server_impl( &self, @@ -818,112 +665,19 @@ impl McpOrchestrator { // Tool Execution // ======================================================================== - /// Call a tool with approval checking and response transformation. - /// - /// This is the main entry point for tool execution. - /// - /// # Arguments - /// * `server_key` - Internal server identifier (may be URL for dynamic servers) - /// * `tool_name` - The tool name to execute - /// * `arguments` - Tool arguments as JSON - /// * `server_label` - User-facing label for API responses - /// * `request_ctx` - Request context for approval - pub async fn call_tool( - &self, - server_key: &str, - tool_name: &str, - arguments: Value, - server_label: &str, - request_ctx: &McpRequestContext<'_>, - ) -> McpResult { - self.active_executions.fetch_add(1, Ordering::SeqCst); - let _guard = scopeguard::guard(Arc::clone(&self.active_executions), |count| { - count.fetch_sub(1, Ordering::SeqCst); - }); - let qualified = QualifiedToolName::new(server_key, tool_name); - - // Get tool entry - let entry = self - .tool_inventory - .get_entry(server_key, tool_name) - .ok_or_else(|| McpError::ToolNotFound(qualified.to_string()))?; - - // Record metrics start - self.metrics.record_call_start(&qualified); - let start_time = Instant::now(); - - // Execute with approval flow - let result = self - .execute_tool_with_approval(&entry, arguments, server_label, request_ctx) - .await; - - // Record metrics end - let duration_ms = start_time.elapsed().as_millis() as u64; - self.metrics - .record_call_end(&qualified, result.is_ok(), duration_ms); - - result - } - /// Find the MCP server configured to handle a built-in tool type. /// - /// When a request includes built-in tools like `{"type": "web_search_preview"}`, - /// routers can use this method to find which MCP server should handle it. - /// - /// # Arguments - /// * `builtin_type` - The built-in tool type to look up - /// - /// # Returns - /// If a server is configured for this built-in type, returns: - /// - `server_key` - Internal identifier for the server (used for `call_tool`) - /// - `tool_name` - The MCP tool to call on that server - /// - `response_format` - The format to use for response transformation - /// - /// Returns `None` if no server is configured for this built-in type. - /// - /// # Example - /// - /// ```ignore - /// // Check if web_search_preview is configured - /// if let Some((server_key, tool_name, format)) = - /// orchestrator.find_builtin_server(BuiltinToolType::WebSearchPreview) - /// { - /// // Route to MCP server - /// let result = orchestrator.call_tool( - /// &server_key, - /// &tool_name, - /// arguments, - /// "web-search", // user-facing label - /// &request_ctx, - /// ).await?; - /// } else { - /// // No MCP configured - handle differently - /// } - /// ``` - pub fn find_builtin_server( - &self, - builtin_type: BuiltinToolType, - ) -> Option<(String, String, ResponseFormat)> { - // Helper to extract builtin info from a server config + /// Returns `(server_key, tool_name)` for the configured server, or + /// `None` if no server handles this builtin. Per-tool `ResponseFormat` + /// lives in the gateway's `FormatRegistry` — callers look it up there. + pub fn find_builtin_server(&self, builtin_type: BuiltinToolType) -> Option<(String, String)> { let extract_builtin = |server_config: &McpServerConfig| { if let (Some(cfg_type), Some(tool_name)) = ( &server_config.builtin_type, &server_config.builtin_tool_name, ) { if *cfg_type == builtin_type { - // Determine response format from tool config or use builtin default - let response_format = server_config - .tools - .as_ref() - .and_then(|tools| tools.get(tool_name)) - .map(|tc| tc.response_format.clone().into()) - .unwrap_or_else(|| builtin_type.response_format().into()); - - return Some(( - server_config.name.clone(), - tool_name.clone(), - response_format, - )); + return Some((server_config.name.clone(), tool_name.clone())); } } None @@ -1054,7 +808,6 @@ impl McpOrchestrator { output: serde_json::json!({ "error": &err }), is_error: true, error_message: Some(err), - response_format: ResponseFormat::Passthrough, duration: start.elapsed(), }) } @@ -1074,7 +827,6 @@ impl McpOrchestrator { }); self.metrics.record_call_start(&qualified); let call_start_time = Instant::now(); - let response_format = entry.response_format.clone(); let result = match self .execute_tool_with_approval_raw_internal(entry, arguments, request_ctx) @@ -1090,7 +842,6 @@ impl McpOrchestrator { output: Self::call_result_to_json(&raw_result), is_error: raw_result.is_error.unwrap_or(false), error_message: None, - response_format: response_format.clone(), duration: call_start_time.elapsed(), }) } @@ -1102,7 +853,6 @@ impl McpOrchestrator { server_label: entry.server_key().to_string(), arguments_str: String::new(), approval_request, - response_format, duration: call_start_time.elapsed(), }) } @@ -1117,7 +867,6 @@ impl McpOrchestrator { output: serde_json::json!({ "error": &err }), is_error: true, error_message: Some(err), - response_format, duration: call_start_time.elapsed(), }) } @@ -1131,48 +880,8 @@ impl McpOrchestrator { result } - /// Execute tool with approval checking. - /// - /// Returns a transformed `ToolCallResult` ready for API responses. - /// - /// # Arguments - /// * `entry` - Tool entry to execute - /// * `arguments` - Tool arguments - /// * `server_label` - User-facing server label for API responses - /// * `request_ctx` - Request context for approval - async fn execute_tool_with_approval( - &self, - entry: &ToolEntry, - arguments: Value, - server_label: &str, - request_ctx: &McpRequestContext<'_>, - ) -> McpResult { - // Delegate to raw implementation and transform result - match self - .execute_tool_with_approval_raw_internal(entry, arguments.clone(), request_ctx) - .await? - { - ApprovalExecutionResult::Success(result) => { - let output = Self::transform_result( - &result, - &entry.response_format, - &request_ctx.request_id, - server_label, - entry.tool_name(), - &arguments.to_string(), - ); - Ok(ToolCallResult::Success(output)) - } - ApprovalExecutionResult::PendingApproval(approval_request) => { - Ok(ToolCallResult::PendingApproval(approval_request)) - } - } - } - /// Internal implementation of approval-checked tool execution. - /// - /// Returns either the raw result or the pending approval request. - /// Both public methods delegate to this to avoid code duplication. + /// Returns either the raw `CallToolResult` or a pending approval request. async fn execute_tool_with_approval_raw_internal( &self, entry: &ToolEntry, @@ -1401,28 +1110,6 @@ impl McpOrchestrator { args } - /// Transform MCP result to OpenAI format. - fn transform_result( - result: &CallToolResult, - format: &ResponseFormat, - tool_call_id: &str, - server_label: &str, - tool_name: &str, - arguments: &str, - ) -> ResponseOutputItem { - // Convert CallToolResult content to JSON for transformation - let result_json = Self::call_result_to_json(result); - - ResponseTransformer::transform( - &result_json, - format, - tool_call_id, - server_label, - tool_name, - arguments, - ) - } - /// Convert CallToolResult to JSON value. fn call_result_to_json(result: &CallToolResult) -> Value { // Serialize the CallToolResult content to JSON @@ -1477,7 +1164,6 @@ impl McpOrchestrator { target_server: &str, target_tool: &str, arg_mapping: Option, - response_format: ResponseFormat, ) -> McpResult<()> { // Verify target exists let target_entry = self @@ -1492,11 +1178,10 @@ impl McpOrchestrator { }; let alias_entry = ToolEntry::new( - QualifiedToolName::new("alias", alias_name), + QualifiedToolName::new(ALIAS_SERVER_KEY, alias_name), target_entry.tool.clone(), ) - .with_alias(alias_target) - .with_response_format(response_format); + .with_alias(alias_target); self.tool_inventory.insert_entry(alias_entry); @@ -1825,39 +1510,6 @@ impl McpOrchestrator { } } - /// Call a tool and continue execution after approval (for continuing paused requests). - /// - /// This is called after the user approves a tool execution in interactive mode. - /// The approval should already be resolved via `resolve_approval()`. - pub async fn continue_tool_execution( - &self, - server_key: &str, - tool_name: &str, - arguments: Value, - request_ctx: &McpRequestContext<'_>, - ) -> McpResult { - // Get tool entry - let entry = self - .tool_inventory - .get_entry(server_key, tool_name) - .ok_or_else(|| McpError::ToolNotFound(format!("{server_key}:{tool_name}")))?; - - // Execute directly (approval already handled) - let result = self.execute_tool_impl(&entry, arguments.clone()).await?; - - // Transform response - let output = Self::transform_result( - &result, - &entry.response_format, - &request_ctx.request_id, - entry.server_key(), - entry.tool_name(), - &arguments.to_string(), - ); - - Ok(ToolCallResult::Success(output)) - } - // ======================================================================== // Lifecycle // ======================================================================== @@ -1973,80 +1625,6 @@ impl<'a> McpRequestContext<'a> { Ok(()) } - /// Call a tool in this request context. - /// - /// # Arguments - /// * `server_key` - Internal server identifier - /// * `tool_name` - Tool name to execute - /// * `arguments` - Tool arguments as JSON - /// * `server_label` - User-facing label for API responses - pub async fn call_tool( - &self, - server_key: &str, - tool_name: &str, - arguments: Value, - server_label: &str, - ) -> McpResult { - self.orchestrator - .active_executions - .fetch_add(1, Ordering::SeqCst); - let _guard = scopeguard::guard(Arc::clone(&self.orchestrator.active_executions), |count| { - count.fetch_sub(1, Ordering::SeqCst); - }); - // Check dynamic tools first - let qualified = QualifiedToolName::new(server_key, tool_name); - if let Some(entry) = self.dynamic_tools.get(&qualified) { - return self - .execute_dynamic_tool(&entry, arguments, server_label) - .await; - } - - // Fall back to orchestrator - self.orchestrator - .call_tool(server_key, tool_name, arguments, server_label, self) - .await - } - - /// Execute a dynamic tool. - async fn execute_dynamic_tool( - &self, - entry: &ToolEntry, - arguments: Value, - server_label: &str, - ) -> McpResult { - let client = self - .dynamic_clients - .get(entry.server_key()) - .ok_or_else(|| McpError::ServerNotFound(entry.server_key().to_string()))?; - - let args_map = if let Value::Object(map) = arguments.clone() { - Some(map) - } else { - None - }; - - let request = CallToolRequestParam { - name: Cow::Owned(entry.tool_name().to_string()), - arguments: args_map, - }; - - let result = client - .call_tool(request) - .await - .map_err(|e| McpError::ToolExecution(format!("MCP call failed: {e}")))?; - - let output = McpOrchestrator::transform_result( - &result, - &entry.response_format, - &self.request_id, - server_label, - entry.tool_name(), - &arguments.to_string(), - ); - - Ok(ToolCallResult::Success(output)) - } - /// List all tools visible in this request context. pub fn list_tools(&self) -> Vec { let mut tools = self.orchestrator.list_tools(Some(&self.tenant_ctx)); @@ -2288,13 +1866,7 @@ mod tests { orchestrator.tool_inventory.insert_entry(entry); // Register alias - let result = orchestrator.register_alias( - "web_search", - "brave", - "brave_web_search", - None, - ResponseFormat::WebSearchCall, - ); + let result = orchestrator.register_alias("web_search", "brave", "brave_web_search", None); assert!(result.is_ok()); assert!(orchestrator @@ -2307,13 +1879,8 @@ mod tests { fn test_alias_registration_missing_target() { let orchestrator = McpOrchestrator::new_test(); - let result = orchestrator.register_alias( - "web_search", - "missing_server", - "missing_tool", - None, - ResponseFormat::Passthrough, - ); + let result = + orchestrator.register_alias("web_search", "missing_server", "missing_tool", None); assert!(result.is_err()); } @@ -2369,7 +1936,7 @@ mod tests { "search_web".to_string(), ToolConfig { alias: None, - response_format: ResponseFormatConfig::WebSearchCall, + response_format: Some(ResponseFormatConfig::WebSearchCall), arg_mapping: Some(ArgMappingConfig { renames: HashMap::new(), defaults: HashMap::from([( @@ -2414,7 +1981,6 @@ mod tests { mapping.overrides, vec![("enable_brave".to_string(), serde_json::json!(false))] ); - assert_eq!(entry.response_format, ResponseFormat::WebSearchCall); } #[test] @@ -2592,10 +2158,9 @@ mod tests { let result = orchestrator.find_builtin_server(BuiltinToolType::WebSearchPreview); assert!(result.is_some()); - let (server_key, tool_name, response_format) = result.unwrap(); + let (server_key, tool_name) = result.unwrap(); assert_eq!(server_key, "brave"); assert_eq!(tool_name, "brave_web_search"); - assert_eq!(response_format, ResponseFormat::WebSearchCall); // Should NOT find a server for code_interpreter let result = orchestrator.find_builtin_server(BuiltinToolType::CodeInterpreter); @@ -2617,7 +2182,7 @@ mod tests { "my_search".to_string(), ToolConfig { alias: None, - response_format: ResponseFormatConfig::Passthrough, // Override default + response_format: Some(ResponseFormatConfig::Passthrough), // Override default arg_mapping: None, }, ); @@ -2661,11 +2226,9 @@ mod tests { let result = orchestrator.find_builtin_server(BuiltinToolType::WebSearchPreview); assert!(result.is_some()); - let (server_key, tool_name, response_format) = result.unwrap(); + let (server_key, tool_name) = result.unwrap(); assert_eq!(server_key, "custom-search"); assert_eq!(tool_name, "my_search"); - // Should use the custom Passthrough format, not the default WebSearchCall - assert_eq!(response_format, ResponseFormat::Passthrough); } #[test] @@ -2683,148 +2246,4 @@ mod tests { .find_builtin_server(BuiltinToolType::FileSearch) .is_none()); } - - #[test] - fn test_apply_builtin_response_format() { - use std::collections::HashMap; - - use crate::{ - approval::{audit::AuditLog, policy::PolicyEngine}, - inventory::types::ToolEntry, - }; - - // Create config with builtin_type but no explicit response_format for the tool - let config = McpConfig { - servers: vec![McpServerConfig { - name: "brave".to_string(), - transport: McpTransport::Sse { - url: "http://localhost:3000/sse".to_string(), - token: None, - headers: HashMap::new(), - }, - proxy: None, - required: false, - tools: None, // No explicit tool config - builtin_type: Some(BuiltinToolType::WebSearchPreview), - builtin_tool_name: Some("brave_search".to_string()), - internal: false, - }], - ..Default::default() - }; - - let (refresh_tx, _) = mpsc::channel(10); - let audit_log = Arc::new(AuditLog::new()); - let policy_engine = Arc::new(PolicyEngine::new(Arc::clone(&audit_log))); - let approval_manager = Arc::new(ApprovalManager::new(policy_engine, audit_log)); - - let orchestrator = McpOrchestrator { - static_servers: DashMap::new(), - tool_inventory: Arc::new(ToolInventory::new()), - approval_manager, - connection_pool: Arc::new(McpConnectionPool::new()), - metrics: Arc::new(McpMetrics::new()), - refresh_tx, - active_executions: Arc::new(AtomicUsize::new(0)), - shutdown_token: CancellationToken::new(), - reconnection_locks: DashMap::new(), - config, - }; - - // Simulate tool discovery - tool is registered with default Passthrough - let tool = create_test_tool("brave_search"); - let entry = ToolEntry::from_server_tool("brave", tool); - assert_eq!(entry.response_format, ResponseFormat::Passthrough); // Default - orchestrator.tool_inventory.insert_entry(entry); - - // Apply builtin response format - should update to WebSearchCall - orchestrator.apply_builtin_response_format(&orchestrator.config.servers[0]); - - // Verify the tool entry was updated - let entry = orchestrator - .tool_inventory - .get_entry("brave", "brave_search") - .expect("Tool should exist"); - assert_eq!( - entry.response_format, - ResponseFormat::WebSearchCall, - "Builtin type should auto-apply WebSearchCall format" - ); - } - - #[test] - fn test_apply_builtin_response_format_with_explicit_override() { - use std::collections::HashMap; - - use crate::{ - approval::{audit::AuditLog, policy::PolicyEngine}, - core::config::{ResponseFormatConfig, ToolConfig}, - inventory::types::ToolEntry, - }; - - // Create config with builtin_type AND explicit response_format override - let mut tools = HashMap::new(); - tools.insert( - "brave_search".to_string(), - ToolConfig { - alias: None, - response_format: ResponseFormatConfig::Passthrough, // Explicit override - arg_mapping: None, - }, - ); - - let config = McpConfig { - servers: vec![McpServerConfig { - name: "brave".to_string(), - transport: McpTransport::Sse { - url: "http://localhost:3000/sse".to_string(), - token: None, - headers: HashMap::new(), - }, - proxy: None, - required: false, - tools: Some(tools), - builtin_type: Some(BuiltinToolType::WebSearchPreview), - builtin_tool_name: Some("brave_search".to_string()), - internal: false, - }], - ..Default::default() - }; - - let (refresh_tx, _) = mpsc::channel(10); - let audit_log = Arc::new(AuditLog::new()); - let policy_engine = Arc::new(PolicyEngine::new(Arc::clone(&audit_log))); - let approval_manager = Arc::new(ApprovalManager::new(policy_engine, audit_log)); - - let orchestrator = McpOrchestrator { - static_servers: DashMap::new(), - tool_inventory: Arc::new(ToolInventory::new()), - approval_manager, - connection_pool: Arc::new(McpConnectionPool::new()), - metrics: Arc::new(McpMetrics::new()), - refresh_tx, - active_executions: Arc::new(AtomicUsize::new(0)), - shutdown_token: CancellationToken::new(), - reconnection_locks: DashMap::new(), - config, - }; - - // Simulate tool discovery - let tool = create_test_tool("brave_search"); - let entry = ToolEntry::from_server_tool("brave", tool); - orchestrator.tool_inventory.insert_entry(entry); - - // Apply builtin response format - should NOT override because explicit config exists - orchestrator.apply_builtin_response_format(&orchestrator.config.servers[0]); - - // Verify the tool entry kept Passthrough (explicit override) - let entry = orchestrator - .tool_inventory - .get_entry("brave", "brave_search") - .expect("Tool should exist"); - assert_eq!( - entry.response_format, - ResponseFormat::Passthrough, - "Explicit override should be preserved" - ); - } } diff --git a/crates/mcp/src/core/session.rs b/crates/mcp/src/core/session.rs index 8e72544e5..e3f796871 100644 --- a/crates/mcp/src/core/session.rs +++ b/crates/mcp/src/core/session.rs @@ -9,9 +9,6 @@ use std::collections::{HashMap, HashSet}; use futures::stream::{self, StreamExt}; -use openai_protocol::responses::{ - McpAllowedTools, RequireApproval, RequireApprovalMode, ResponseTool, -}; use super::{ config::BuiltinToolType, @@ -24,12 +21,7 @@ use super::{ use crate::{ approval::ApprovalMode, inventory::{QualifiedToolName, ToolCategory, ToolEntry}, - responses_bridge::{ - build_chat_function_tools_with_names, build_function_tools_json_with_names, - build_mcp_list_tools_item, build_mcp_list_tools_json, build_response_tools_with_names, - }, tenant::TenantContext, - transform::ResponseFormat, }; /// Default user-facing label for MCP servers when no explicit label is provided. @@ -64,7 +56,6 @@ struct ExposedToolBinding { server_label: String, resolved_tool_name: String, is_builtin_routed: bool, - response_format: ResponseFormat, approval_mode: ApprovalMode, } @@ -314,7 +305,6 @@ impl<'a> McpToolSession<'a> { output: serde_json::json!({ "error": &err }), is_error: true, error_message: Some(err), - response_format: ResponseFormat::Passthrough, duration: std::time::Duration::default(), }) } @@ -338,62 +328,24 @@ impl<'a> McpToolSession<'a> { .unwrap_or_else(|| fallback_label.to_string()) } - /// Apply request-time approval configuration to exposed tools in this session. - pub fn configure_response_tools_approval(&mut self, tools: &[ResponseTool]) { - for tool in tools { - let ResponseTool::Mcp(mcp_tool) = tool else { - continue; - }; - - let approval_mode = match mcp_tool.require_approval.as_ref() { - Some(RequireApproval::Mode(RequireApprovalMode::Always)) => { - ApprovalMode::Interactive - } - _ => ApprovalMode::PolicyOnly, - }; - - if approval_mode == ApprovalMode::PolicyOnly { + /// Set the approval mode for every binding matching `server_label`, + /// optionally narrowed to a subset of resolved tool names. + pub fn set_approval_mode( + &mut self, + server_label: &str, + allowed_tool_names: Option<&[String]>, + mode: ApprovalMode, + ) { + for binding in self.exposed_name_map.values_mut() { + if binding.server_label != server_label { continue; } - - // T11: the legacy `allowed_tools: Vec` wire shape is now - // `McpAllowedTools` (untagged union of `List(Vec)` or - // `Filter(McpToolFilter { read_only?, tool_names? })`). Project - // union variants back into the flat name-list scoping used here: - // * `None`, or `Filter { None, None }` → no name constraint - // (all bindings for this server inherit the explicit approval - // mode). - // * `List(names)` / `Filter { tool_names: Some(v), .. }` → - // constrain by explicit names. - // * `Filter { tool_names: None, read_only: Some(_) }` → `None`. - // `readOnlyHint`-based filtering is unimplemented, but the - // safe-default direction for *approval scoping* is the - // opposite of exposure: narrowing to an empty name list here - // would drop the caller's explicit approval mode for all - // bindings (they'd fall back to `PolicyOnly`, which is - // auto-approve-by-policy — LESS restrictive). Returning - // `None` applies the requested approval mode to every - // binding on the server, matching the "over-gate is safer - // than under-gate" contract for approval prompts. - let allowed_tool_names: Option<&[String]> = - mcp_tool.allowed_tools.as_ref().and_then(|at| match at { - McpAllowedTools::List(names) => Some(names.as_slice()), - McpAllowedTools::Filter(filter) => filter.tool_names.as_deref(), - }); - for binding in self.exposed_name_map.values_mut() { - if binding.server_label != mcp_tool.server_label { + if let Some(allowed) = allowed_tool_names { + if !allowed.iter().any(|n| n == &binding.resolved_tool_name) { continue; } - if let Some(allowed_tool_names) = allowed_tool_names { - if !allowed_tool_names - .iter() - .any(|allowed_tool_name| allowed_tool_name == &binding.resolved_tool_name) - { - continue; - } - } - binding.approval_mode = approval_mode; } + binding.approval_mode = mode; } } @@ -455,235 +407,41 @@ impl<'a> McpToolSession<'a> { .collect() } - /// Look up the response format for a tool. - /// - /// Convenience method that returns `Passthrough` if the tool is not found. - pub fn tool_response_format(&self, tool_name: &str) -> ResponseFormat { - self.exposed_name_map - .get(tool_name) - .map(|binding| binding.response_format.clone()) - .unwrap_or(ResponseFormat::Passthrough) - } - - /// Build function-tool JSON payloads for upstream model calls. - pub fn build_function_tools_json(&self) -> Vec { - build_function_tools_json_with_names(&self.mcp_tools, Some(&self.exposed_name_by_qualified)) - } - - /// Build Chat API `Tool` structs for chat completions. - pub fn build_chat_function_tools(&self) -> Vec { - build_chat_function_tools_with_names(&self.mcp_tools, Some(&self.exposed_name_by_qualified)) - } - - /// Build Responses API `ResponseTool` structs. - pub fn build_response_tools(&self) -> Vec { - build_response_tools_with_names(&self.mcp_tools, Some(&self.exposed_name_by_qualified)) - } - - /// Build `mcp_list_tools` JSON for a specific server. - pub fn build_mcp_list_tools_json( - &self, - server_label: &str, - server_key: &str, - ) -> serde_json::Value { - let tools = self.list_tools_for_server(server_key); - build_mcp_list_tools_json(server_label, &tools) + /// Resolve an exposed tool name (post-alias) to its `QualifiedToolName`. + pub fn qualified_name_for_exposed(&self, tool_name: &str) -> Option { + let binding = self.exposed_name_map.get(tool_name)?; + Some(QualifiedToolName::new( + &binding.server_key, + &binding.resolved_tool_name, + )) } - /// Build typed `mcp_list_tools` output item for a specific server. - pub fn build_mcp_list_tools_item( + /// True when a tool name should be hidden because the underlying server + /// is internal-non-builtin and the user didn't explicitly declare it as a + /// function tool. + pub fn should_hide_function_call_like( &self, - server_label: &str, - server_key: &str, - ) -> openai_protocol::responses::ResponseOutputItem { - let tools = self.list_tools_for_server(server_key); - build_mcp_list_tools_item(server_label, &tools) - } - - /// Inject MCP metadata into a response output array. - /// - /// Standardized ordering: - /// 1. `mcp_list_tools` items (one per server) — prepended - /// 2. `tool_call_items` (mcp_call / web_search_call / etc.) — after list_tools - /// 3. Existing items (messages, etc.) — remain at end - /// - /// Test-only helper for legacy ordering assertions. - /// Production code should use `inject_client_visible_mcp_output_items`. - #[cfg(test)] - fn inject_mcp_output_items( - &self, - output: &mut Vec, - tool_call_items: Vec, - ) { - // Modify the vector in-place: take existing items, then rebuild - // with the correct ordering without allocating a temporary Vec. - let existing = std::mem::take(output); - output.reserve(self.mcp_servers.len() + tool_call_items.len() + existing.len()); - - // 1. mcp_list_tools items (one per server) - for binding in &self.mcp_servers { - output.push(self.build_mcp_list_tools_item(&binding.label, &binding.server_key)); - } - - // 2. Tool call items (mcp_call / web_search_call / etc.) - output.extend(tool_call_items); - - // 3. Existing items (messages, etc.) - output.extend(existing); - } - - /// Inject only client-visible MCP metadata and call items into response output. - /// - /// Visibility policy: - /// - Hide builtin `mcp_list_tools` (builtin tools surface under their own type) - /// - Hide internal non-builtin `mcp_list_tools` - /// - Hide internal non-builtin passthrough `mcp_call`/`mcp_approval_request` - /// - Keep builtin-routed call items visible - /// - Keep user-defined function calls visible even on name collisions - pub fn inject_client_visible_mcp_output_items( - &self, - output: &mut Vec, - tool_call_items: Vec, - user_function_names: &HashSet, - ) { - let existing = std::mem::take(output); - output.reserve(self.mcp_servers.len() + tool_call_items.len() + existing.len()); - - // Use mcp_servers (excludes builtin) to match streaming path behavior. - for binding in &self.mcp_servers { - if !self.is_internal_non_builtin_server_label(&binding.label) { - output.push(self.build_mcp_list_tools_item(&binding.label, &binding.server_key)); - } - } - - for item in tool_call_items { - if self.is_client_visible_output_item(&item, user_function_names) { - output.push(item); - } - } - - // Apply the same visibility policy to existing items (e.g. FunctionToolCall - // for executed MCP tools emitted by build_tool_response in the mixed - // function+MCP early-exit path). - for item in existing { - if self.is_client_visible_output_item(&item, user_function_names) { - output.push(item); - } - } - } - - fn is_client_visible_output_item( - &self, - item: &openai_protocol::responses::ResponseOutputItem, - user_function_names: &HashSet, - ) -> bool { - use openai_protocol::responses::ResponseOutputItem; - - match item { - ResponseOutputItem::McpListTools { server_label, .. } => { - !self.is_builtin_server_label(server_label) - && !self.is_internal_non_builtin_server_label(server_label) - } - ResponseOutputItem::McpCall { - server_label, name, .. - } - | ResponseOutputItem::McpApprovalRequest { - server_label, name, .. - } => !self.should_hide_mcp_call_like_by_label(name, server_label), - ResponseOutputItem::FunctionToolCall { name, .. } => { - !self.should_hide_function_call_like(name, user_function_names) - } - ResponseOutputItem::WebSearchCall { .. } - | ResponseOutputItem::CodeInterpreterCall { .. } - | ResponseOutputItem::FileSearchCall { .. } - | ResponseOutputItem::ImageGenerationCall { .. } - | ResponseOutputItem::ComputerCall { .. } - | ResponseOutputItem::ComputerCallOutput { .. } - | ResponseOutputItem::ShellCall { .. } - | ResponseOutputItem::ShellCallOutput { .. } - | ResponseOutputItem::ApplyPatchCall { .. } - | ResponseOutputItem::ApplyPatchCallOutput { .. } - | ResponseOutputItem::Message { .. } - | ResponseOutputItem::Reasoning { .. } - | ResponseOutputItem::Compaction { .. } - | ResponseOutputItem::LocalShellCall { .. } - | ResponseOutputItem::LocalShellCallOutput { .. } => true, - } - } - - /// Returns true when a JSON tool entry should be hidden from client-facing responses. - /// - /// This is used by OpenAI non-streaming response normalization, where tools are handled - /// as `serde_json::Value` payloads instead of typed `ResponseOutputItem`s. - pub fn should_hide_tool_json( - &self, - tool: &serde_json::Value, + name: &str, user_function_names: &HashSet, ) -> bool { - match tool.get("type").and_then(|value| value.as_str()) { - Some("function") => Self::function_tool_name_json(tool) - .is_some_and(|name| self.should_hide_function_call_like(name, user_function_names)), - // MCP tool entries are keyed by server metadata, so function-name collision - // handling does not apply to this arm. - Some("mcp") => tool - .get("server_label") - .and_then(|value| value.as_str()) - .is_some_and(|server_label| { - self.is_internal_non_builtin_server_label(server_label) - }), - _ => false, - } + self.is_internal_tool(name) && !user_function_names.contains(name) } - /// Returns true when a JSON output item should be hidden from client-facing responses. - /// - /// This keeps OpenAI non-streaming redaction aligned with session-level policy. - pub fn should_hide_output_item_json( - &self, - item: &serde_json::Value, - user_function_names: &HashSet, - ) -> bool { - match item.get("type").and_then(|value| value.as_str()) { - // mcp_list_tools is gateway-synthesized metadata. Hide for builtin servers - // (implementation detail) and internal non-builtin servers (privacy). - Some("mcp_list_tools") => item - .get("server_label") - .and_then(|value| value.as_str()) - .is_some_and(|server_label| { - self.is_builtin_server_label(server_label) - || self.is_internal_non_builtin_server_label(server_label) - }), - Some("mcp_call") | Some("mcp_approval_request") => { - let matches_internal_server = item - .get("server_label") - .and_then(|value| value.as_str()) - .is_some_and(|server_label| { - self.is_internal_non_builtin_server_label(server_label) - }); - - match item.get("name").and_then(|value| value.as_str()) { - Some(name) => { - self.should_hide_mcp_call_like_by_server_flag(name, matches_internal_server) - } - _ => matches_internal_server, - } - } - Some("function_call") | Some("function_tool_call") => item - .get("name") - .and_then(|value| value.as_str()) - .is_some_and(|name| self.should_hide_function_call_like(name, user_function_names)), - _ => false, + /// True when an `mcp_call`/`mcp_approval_request` item should be hidden. + /// `name` is the tool name; `server_label` is the user-facing label. + pub fn should_hide_mcp_call_like_by_label(&self, name: &str, server_label: &str) -> bool { + let matches_internal_server = self.is_internal_non_builtin_server_label(server_label); + if self.has_exposed_tool(name) { + self.is_internal_non_builtin_tool(name) + } else { + matches_internal_server } } - fn should_hide_mcp_call_like_by_label(&self, name: &str, server_label: &str) -> bool { - self.should_hide_mcp_call_like_by_server_flag( - name, - self.is_internal_non_builtin_server_label(server_label), - ) - } - - fn should_hide_mcp_call_like_by_server_flag( + /// Variant of `should_hide_mcp_call_like_by_label` that takes the + /// pre-resolved internal-server flag, used by JSON-shape filters that + /// already inspected the `server_label` field. + pub fn should_hide_mcp_call_like_by_server_flag( &self, name: &str, matches_internal_server: bool, @@ -695,28 +453,6 @@ impl<'a> McpToolSession<'a> { } } - fn should_hide_function_call_like( - &self, - name: &str, - user_function_names: &HashSet, - ) -> bool { - self.is_internal_tool(name) && !user_function_names.contains(name) - } - - fn function_tool_name_json(tool: &serde_json::Value) -> Option<&str> { - let tool_type = tool.get("type").and_then(|value| value.as_str()); - if tool_type != Some("function") { - return None; - } - tool.get("name") - .and_then(|value| value.as_str()) - .or_else(|| { - tool.get("function") - .and_then(|function| function.get("name")) - .and_then(|value| value.as_str()) - }) - } - fn build_exposed_function_tools( tools: &[ToolEntry], mcp_servers: &[McpServerBinding], @@ -786,7 +522,6 @@ impl<'a> McpToolSession<'a> { server_label, resolved_tool_name, is_builtin_routed, - response_format: entry.response_format.clone(), approval_mode: ApprovalMode::PolicyOnly, }, ); @@ -867,7 +602,7 @@ impl<'a> McpToolSession<'a> { ] .into_iter() .filter_map(|builtin_type| orchestrator.find_builtin_server(builtin_type)) - .map(|(server_key, tool_name, _)| QualifiedToolName::new(server_key, tool_name)) + .map(|(server_key, tool_name)| QualifiedToolName::new(server_key, tool_name)) .collect() } @@ -984,15 +719,6 @@ mod tests { assert_eq!(label, DEFAULT_SERVER_LABEL); } - #[test] - fn test_tool_response_format_default() { - let orchestrator = McpOrchestrator::new_test(); - let session = McpToolSession::new(&orchestrator, vec![], "test-request"); - - let format = session.tool_response_format("nonexistent"); - assert!(matches!(format, ResponseFormat::Passthrough)); - } - fn create_test_tool(name: &str) -> McpTool { use std::{borrow::Cow, sync::Arc}; @@ -1044,19 +770,7 @@ mod tests { }], "test-request", ); - session.configure_response_tools_approval(&[ResponseTool::Mcp( - openai_protocol::responses::McpTool { - server_url: Some("http://example.com/mcp".to_string()), - authorization: None, - headers: None, - server_label: "label1".to_string(), - server_description: None, - require_approval: Some(RequireApproval::Mode(RequireApprovalMode::Always)), - allowed_tools: None, - connector_id: None, - defer_loading: None, - }, - )]); + session.set_approval_mode("label1", None, ApprovalMode::Interactive); let result = session .execute_tool_result(ToolExecutionInput { @@ -1372,7 +1086,6 @@ mod tests { crate::inventory::ArgMapping::new() .with_override("enable_brave", serde_json::json!(false)), ), - ResponseFormat::WebSearchCall, ) .expect("alias registration should succeed"); @@ -1391,10 +1104,6 @@ mod tests { assert_eq!(session.mcp_tools().len(), 1); assert_eq!(session.mcp_tools()[0].tool_name(), "web_search"); assert_eq!(session.resolve_tool_server_label("web_search"), "brave"); - assert_eq!( - session.tool_response_format("web_search"), - ResponseFormat::WebSearchCall - ); let listed = session.list_tools_for_server("server1"); assert_eq!(listed.len(), 1); @@ -1413,13 +1122,7 @@ mod tests { )); orchestrator - .register_alias( - "web_search", - "server1", - "brave_web_search", - None, - ResponseFormat::WebSearchCall, - ) + .register_alias("web_search", "server1", "brave_web_search", None) .expect("alias registration should succeed"); let session = McpToolSession::new( @@ -1513,13 +1216,7 @@ mod tests { )); orchestrator - .register_alias( - "alias_search", - "internal-server", - "internal_search", - None, - ResponseFormat::Passthrough, - ) + .register_alias("alias_search", "internal-server", "internal_search", None) .expect("alias registration should succeed"); let session = McpToolSession::new( @@ -1598,122 +1295,6 @@ mod tests { ); } - /// Verify that `inject_mcp_output_items` produces the exact ordering: - /// 1. mcp_list_tools items (one per server, in server order) - /// 2. tool_call_items (in their original order) - /// 3. existing output items (in their original order) - /// - /// This is a regression test so future perf refactors cannot - /// accidentally change the output ordering contract. - #[test] - fn test_inject_mcp_output_items_ordering() { - use openai_protocol::responses::ResponseOutputItem; - - let orchestrator = McpOrchestrator::new_test(); - - // Register one tool per server so build_mcp_list_tools_item has - // something to return. - orchestrator - .tool_inventory() - .insert_entry(ToolEntry::from_server_tool( - "srv_a", - create_test_tool("tool_a"), - )); - orchestrator - .tool_inventory() - .insert_entry(ToolEntry::from_server_tool( - "srv_b", - create_test_tool("tool_b"), - )); - - let session = McpToolSession::new( - &orchestrator, - vec![ - McpServerBinding { - label: "Server A".to_string(), - server_key: "srv_a".to_string(), - allowed_tools: None, - }, - McpServerBinding { - label: "Server B".to_string(), - server_key: "srv_b".to_string(), - allowed_tools: None, - }, - ], - "test-ordering", - ); - - // Pre-existing output items (e.g. assistant message). - let existing_1 = ResponseOutputItem::Message { - id: "msg_existing_1".to_string(), - role: "assistant".to_string(), - content: vec![], - status: "completed".to_string(), - phase: None, - }; - let existing_2 = ResponseOutputItem::Message { - id: "msg_existing_2".to_string(), - role: "assistant".to_string(), - content: vec![], - status: "completed".to_string(), - phase: None, - }; - - // Tool call items injected by the router. - let call_1 = ResponseOutputItem::McpCall { - id: "call_1".to_string(), - status: "completed".to_string(), - approval_request_id: None, - arguments: "{}".to_string(), - error: None, - name: "tool_a".to_string(), - output: "result_a".to_string(), - server_label: "Server A".to_string(), - }; - let call_2 = ResponseOutputItem::McpCall { - id: "call_2".to_string(), - status: "completed".to_string(), - approval_request_id: None, - arguments: "{}".to_string(), - error: None, - name: "tool_b".to_string(), - output: "result_b".to_string(), - server_label: "Server B".to_string(), - }; - - let mut output = vec![existing_1, existing_2]; - let tool_call_items = vec![call_1, call_2]; - - session.inject_mcp_output_items(&mut output, tool_call_items); - - // Expected ordering: 2 mcp_list_tools + 2 mcp_call + 2 messages = 6 - assert_eq!(output.len(), 6, "expected 6 items in output"); - - // Serialize to JSON values for easier field-level assertions. - let items: Vec = output - .iter() - .map(|item| serde_json::to_value(item).expect("serialization failed")) - .collect(); - - // [0..2] mcp_list_tools — one per server, in server order - assert_eq!(items[0]["type"], "mcp_list_tools"); - assert_eq!(items[0]["server_label"], "Server A"); - assert_eq!(items[1]["type"], "mcp_list_tools"); - assert_eq!(items[1]["server_label"], "Server B"); - - // [2..4] tool call items in original order - assert_eq!(items[2]["type"], "mcp_call"); - assert_eq!(items[2]["id"], "call_1"); - assert_eq!(items[3]["type"], "mcp_call"); - assert_eq!(items[3]["id"], "call_2"); - - // [4..6] existing items in original order - assert_eq!(items[4]["type"], "message"); - assert_eq!(items[4]["id"], "msg_existing_1"); - assert_eq!(items[5]["type"], "message"); - assert_eq!(items[5]["id"], "msg_existing_2"); - } - #[test] fn test_allowed_tools_filters_only_target_server() { let orchestrator = McpOrchestrator::new_test(); diff --git a/crates/mcp/src/inventory/mod.rs b/crates/mcp/src/inventory/mod.rs index 8f4b8b3fa..1eea0c36f 100644 --- a/crates/mcp/src/inventory/mod.rs +++ b/crates/mcp/src/inventory/mod.rs @@ -9,4 +9,6 @@ pub mod index; pub mod types; pub use index::{IndexCounts, ToolInventory}; -pub use types::{AliasTarget, ArgMapping, QualifiedToolName, ToolCategory, ToolEntry}; +pub use types::{ + AliasTarget, ArgMapping, QualifiedToolName, ToolCategory, ToolEntry, ALIAS_SERVER_KEY, +}; diff --git a/crates/mcp/src/inventory/types.rs b/crates/mcp/src/inventory/types.rs index 92feebab7..0cf66556e 100644 --- a/crates/mcp/src/inventory/types.rs +++ b/crates/mcp/src/inventory/types.rs @@ -5,9 +5,7 @@ use std::{fmt, sync::Arc, time::Duration}; use serde::{Deserialize, Serialize}; use tokio::time::Instant; -use crate::{ - annotations::ToolAnnotations, core::config::Tool, tenant::TenantId, transform::ResponseFormat, -}; +use crate::{annotations::ToolAnnotations, core::config::Tool, tenant::TenantId}; /// Category of a tool for filtering and visibility control. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)] @@ -21,6 +19,12 @@ pub enum ToolCategory { Builtin, } +/// Synthetic `server_key` used by alias entries created via +/// `McpOrchestrator::register_alias`. Exported so gateway-side code that +/// indexes by `QualifiedToolName` (e.g. the response-format registry) can +/// reconstruct the same key without hardcoding the literal. +pub const ALIAS_SERVER_KEY: &str = "alias"; + /// Unique tool identifier: `server_key:tool_name`. /// /// Uses `Arc` internally for cheap cloning in hot paths. @@ -138,8 +142,6 @@ pub struct ToolEntry { pub arg_mapping: Option, pub cached_at: Instant, pub ttl: Option, - /// Response format for transforming MCP results to API-specific formats. - pub response_format: ResponseFormat, } impl ToolEntry { @@ -154,7 +156,6 @@ impl ToolEntry { arg_mapping: None, cached_at: Instant::now(), ttl: None, - response_format: ResponseFormat::default(), } } @@ -200,12 +201,6 @@ impl ToolEntry { self } - #[must_use] - pub fn with_response_format(mut self, response_format: ResponseFormat) -> Self { - self.response_format = response_format; - self - } - pub fn is_expired(&self) -> bool { self.ttl .map(|ttl| self.cached_at.elapsed() > ttl) diff --git a/crates/mcp/src/lib.rs b/crates/mcp/src/lib.rs index 352cd95ec..66107fad4 100644 --- a/crates/mcp/src/lib.rs +++ b/crates/mcp/src/lib.rs @@ -1,63 +1,41 @@ //! Model Context Protocol (MCP) client implementation. //! -//! ## Modules +//! Crate is OpenAI-protocol-free; gateway-side adapter logic lives in +//! `model_gateway::routers::common::openai_bridge`. //! -//! - [`core`]: MCP client infrastructure (manager, config, connections) -//! - [`inventory`]: Tool storage and indexing -//! - [`approval`]: Approval system for tool execution -//! -//! ## Shared Types -//! -//! - [`ToolAnnotations`]: Tool behavior hints (read_only, destructive, etc.) -//! - [`TenantContext`]: Per-tenant isolation and configuration +//! Modules: +//! - [`core`] — orchestrator, sessions, transports, oauth, config +//! - [`inventory`] — tool registry + qualified naming +//! - [`approval`] — interactive/policy approval engine +//! - [`annotations`], [`tenant`], [`error`] — shared cross-module types -// Shared types (used across modules) pub mod annotations; -pub mod error; -pub mod tenant; -pub mod transform; - -// Subsystems pub mod approval; pub mod core; +pub mod error; pub mod inventory; -pub mod responses_bridge; - -// Backward-compatible re-exports (old module paths) -// These allow `mcp::config::*` to continue working -pub use core::{config, pool as connection_pool}; +pub mod tenant; // Re-export from core pub use core::{ ArgMappingConfig, BuiltinToolType, ConfigValidationError, HandlerRequestContext, LatencySnapshot, McpConfig, McpMetrics, McpOrchestrator, McpRequestContext, McpServerBinding, McpServerConfig, McpToolSession, McpTransport, MetricsSnapshot, PendingToolExecution, PolicyConfig, PolicyDecisionConfig, PoolKey, RefreshRequest, ResponseFormatConfig, - ServerPolicyConfig, SmgClientHandler, Tool, ToolCallResult, ToolConfig, ToolExecutionInput, + ServerPolicyConfig, SmgClientHandler, Tool, ToolConfig, ToolExecutionInput, ToolExecutionOutput, ToolExecutionResult, TrustLevelConfig, DEFAULT_SERVER_LABEL, }; // Re-export shared types pub use annotations::{AnnotationType, ToolAnnotations}; -// Re-export from approval -pub use approval::{ - ApprovalDecision, ApprovalKey, ApprovalManager, ApprovalMode, ApprovalOutcome, ApprovalParams, - AuditEntry, AuditLog, DecisionResult, DecisionSource, McpApprovalRequest, McpApprovalResponse, - PolicyDecision, PolicyEngine, PolicyRule, RuleCondition, RulePattern, ServerPolicy, TrustLevel, -}; +// Re-export from approval. Only `ApprovalMode` is surfaced at the crate root — +// production gateway code never instantiates `ApprovalManager`, `PolicyEngine`, +// `AuditLog`, etc. directly (those are reached through `McpOrchestrator`). +// The remaining symbols stay accessible via `smg_mcp::approval::*` for code +// that genuinely needs them, but they no longer pollute the flat root namespace. +pub use approval::ApprovalMode; pub use error::{ApprovalError, McpError, McpResult}; // Re-export from inventory pub use inventory::{ AliasTarget, ArgMapping, QualifiedToolName, ToolCategory, ToolEntry, ToolInventory, }; -pub use responses_bridge::{ - build_chat_function_tools, build_chat_function_tools_with_names, build_function_tools_json, - build_function_tools_json_with_names, build_mcp_list_tools_item, build_mcp_list_tools_json, - build_mcp_tool_infos, build_response_tools, build_response_tools_with_names, -}; pub use tenant::{SessionId, TenantContext, TenantId}; -// Re-export from transform -pub use transform::{ - apply_hosted_tool_overrides, compact_image_generation_output, - extract_embedded_openai_responses, extract_hosted_tool_overrides, mcp_response_item_id, - ResponseFormat, ResponseTransformer, -}; diff --git a/crates/mcp/src/responses_bridge.rs b/crates/mcp/src/responses_bridge.rs deleted file mode 100644 index 7cd956f4b..000000000 --- a/crates/mcp/src/responses_bridge.rs +++ /dev/null @@ -1,177 +0,0 @@ -//! Shared builders for Responses/Chat tool payloads derived from MCP tool inventory. -//! -//! This module centralizes conversion logic that was previously duplicated in routers: -//! - MCP ToolEntry -> function tool JSON (for upstream model calls) -//! - MCP ToolEntry -> chat/common Tool structs -//! - MCP ToolEntry -> Responses ResponseTool structs -//! - MCP ToolEntry list -> mcp_list_tools output item payloads -//! -//! # Schema cloning -//! -//! Each `ToolEntry` stores its JSON Schema as `Arc`, which enables -//! cheap sharing across the inventory. The downstream protocol types (`Function`, -//! `McpToolInfo`, and `serde_json::Value`) all require an *owned* `serde_json::Map` -//! inside `Value::Object`, so every builder must deep-clone the schema map once per -//! tool per call. This is intentional -- the clone happens O(tools) times per -//! request and schema maps are typically small (a handful of properties). If -//! profiling ever shows this to be a bottleneck, consider caching the materialised -//! `Value::Object` alongside the `Arc` in `ToolEntry`. - -use openai_protocol::{ - common::{Function, Tool}, - responses::{generate_id, FunctionTool, McpToolInfo, ResponseOutputItem, ResponseTool}, -}; -use serde_json::{json, Value}; - -use crate::inventory::{QualifiedToolName, ToolEntry}; - -/// Materialise a `serde_json::Map` reference into an owned `Value::Object`. -/// -/// This deep-clones the map. All builder functions in this module need an -/// owned `Value` because the downstream protocol structs (`Function.parameters`, -/// `McpToolInfo.input_schema`) are `Value`-typed and will be serialised into the -/// API response. See the module-level "Schema cloning" docs for why this is -/// unavoidable. -#[inline] -fn schema_to_value(schema: &serde_json::Map) -> Value { - Value::Object(schema.clone()) -} - -fn resolved_name_for_entry<'a>( - entry: &'a ToolEntry, - exposed_names: Option<&'a std::collections::HashMap>, -) -> &'a str { - exposed_names - .and_then(|m| m.get(&entry.qualified_name)) - .map(|s| s.as_str()) - .unwrap_or_else(|| entry.tool_name()) -} - -/// Resolved (name, description, &input_schema) triples from MCP tool entries. -/// -/// This is the shared extraction logic used by the JSON, Chat, and Responses -/// builder functions so that name-resolution lives in one place. The schema is -/// returned by reference; callers clone when they need an owned `Value`. -fn resolved_tool_fields<'a>( - entries: &'a [ToolEntry], - exposed_names: Option<&'a std::collections::HashMap>, -) -> impl Iterator, &'a serde_json::Map)> + 'a { - entries.iter().map(move |entry| { - let name = resolved_name_for_entry(entry, exposed_names); - let description = entry.tool.description.as_deref(); - (name, description, &*entry.tool.input_schema) - }) -} - -/// Build function-tool JSON payloads from MCP tool entries. -/// -/// These are used when routers expose MCP tools as function tools to upstream model APIs. -pub fn build_function_tools_json(entries: &[ToolEntry]) -> Vec { - build_function_tools_json_with_names(entries, None) -} - -/// Build function-tool JSON payloads from MCP tool entries with optional exposed names. -pub fn build_function_tools_json_with_names( - entries: &[ToolEntry], - exposed_names: Option<&std::collections::HashMap>, -) -> Vec { - resolved_tool_fields(entries, exposed_names) - .map(|(name, description, parameters)| { - json!({ - "type": "function", - "name": name, - "description": description, - "parameters": schema_to_value(parameters) - }) - }) - .collect() -} - -/// Build Chat API function tools from MCP tool entries. -pub fn build_chat_function_tools(entries: &[ToolEntry]) -> Vec { - build_chat_function_tools_with_names(entries, None) -} - -/// Build Chat API function tools from MCP tool entries with optional exposed names. -pub fn build_chat_function_tools_with_names( - entries: &[ToolEntry], - exposed_names: Option<&std::collections::HashMap>, -) -> Vec { - resolved_tool_fields(entries, exposed_names) - .map(|(name, description, parameters)| Tool { - tool_type: "function".to_string(), - function: Function { - name: name.to_string(), - description: description.map(|d| d.to_string()), - parameters: schema_to_value(parameters), - strict: None, - }, - }) - .collect() -} - -/// Build Responses API function tools from MCP tool entries. -/// -/// MCP tools are exposed to the model as function tools in the Responses API, -/// so these serialize as `{"type": "function", ...}` tool entries. -pub fn build_response_tools(entries: &[ToolEntry]) -> Vec { - build_response_tools_with_names(entries, None) -} - -/// Build Responses API function tools from MCP tool entries with optional exposed names. -pub fn build_response_tools_with_names( - entries: &[ToolEntry], - exposed_names: Option<&std::collections::HashMap>, -) -> Vec { - resolved_tool_fields(entries, exposed_names) - .map(|(name, description, parameters)| { - ResponseTool::Function(FunctionTool { - function: Function { - name: name.to_string(), - description: description.map(|d| d.to_string()), - parameters: schema_to_value(parameters), - strict: None, - }, - }) - }) - .collect() -} - -/// Build MCP tool infos used by `mcp_list_tools` output items. -pub fn build_mcp_tool_infos(entries: &[ToolEntry]) -> Vec { - entries - .iter() - .map(|entry| McpToolInfo { - name: entry.tool_name().to_string(), - description: entry.tool.description.as_ref().map(|d| d.to_string()), - input_schema: schema_to_value(&entry.tool.input_schema), - annotations: entry - .tool - .annotations - .as_ref() - .and_then(|a| serde_json::to_value(a).ok()), - }) - .collect() -} - -/// Build a typed `mcp_list_tools` output item. -pub fn build_mcp_list_tools_item(server_label: &str, entries: &[ToolEntry]) -> ResponseOutputItem { - ResponseOutputItem::McpListTools { - id: generate_id("mcpl"), - server_label: server_label.to_string(), - tools: build_mcp_tool_infos(entries), - // T11: `error` is populated when the MCP server failed to list tools; - // this constructor synthesizes a successful listing from tool entries, - // so no error is attached. - error: None, - } -} - -/// Build a JSON `mcp_list_tools` output item payload. -/// -/// Useful for routers that build/manipulate raw JSON responses. -pub fn build_mcp_list_tools_json(server_label: &str, entries: &[ToolEntry]) -> Value { - serde_json::to_value(build_mcp_list_tools_item(server_label, entries)).unwrap_or_else( - |_| json!({ "type": "mcp_list_tools", "server_label": server_label, "tools": [] }), - ) -} diff --git a/crates/mcp/src/transform/mod.rs b/crates/mcp/src/transform/mod.rs deleted file mode 100644 index a936e2b8e..000000000 --- a/crates/mcp/src/transform/mod.rs +++ /dev/null @@ -1,31 +0,0 @@ -//! Response transformation for MCP to API-specific formats. -//! -//! This module provides transformation from MCP `CallToolResult` responses -//! to OpenAI Responses API formats (web_search_call, code_interpreter_call, etc.). -//! -//! # Example -//! -//! ```ignore -//! use smg_mcp::transform::{ResponseFormat, ResponseTransformer}; -//! -//! let mcp_result = serde_json::json!({"results": [{"url": "https://example.com"}]}); -//! let output = ResponseTransformer::transform( -//! &mcp_result, -//! &ResponseFormat::WebSearchCall, -//! "call-123", -//! "brave", -//! "web_search", -//! "{}", -//! ); -//! ``` - -mod overrides; -mod transformer; -mod types; - -pub use overrides::{apply_hosted_tool_overrides, extract_hosted_tool_overrides}; -pub use transformer::{ - compact_image_generation_output, extract_embedded_openai_responses, mcp_response_item_id, - ResponseTransformer, -}; -pub use types::ResponseFormat; diff --git a/e2e_test/responses/test_image_generation.py b/e2e_test/responses/test_image_generation.py index 125a6ec0a..e2c7848b8 100644 --- a/e2e_test/responses/test_image_generation.py +++ b/e2e_test/responses/test_image_generation.py @@ -33,6 +33,7 @@ from __future__ import annotations +import json import logging import httpx @@ -255,22 +256,6 @@ def scoped_idx(evt: str) -> int: ) -def _extract_conversation_id(resp) -> str | None: - """Normalise the conversation id across SDK variations. - - The OpenAI Responses SDK surfaces the conversation id as either - ``resp.conversation_id`` (a plain string) or ``resp.conversation`` (an - object with an ``.id`` attribute), and older releases expose neither. - This helper collapses those cases to a single ``str | None``. - """ - conv_attr = getattr(resp, "conversation_id", None) or getattr(resp, "conversation", None) - if conv_attr is None: - return None - if isinstance(conv_attr, str): - return conv_attr - return getattr(conv_attr, "id", None) - - # ============================================================================= # Shared test mix-in body # ============================================================================= @@ -427,9 +412,22 @@ def test_image_generation_user_forwarded_to_mcp(self, request, image_gen_tool_ar ) def test_image_generation_compactor_strips_base64(self, request, image_gen_tool_args) -> None: - """Multi-turn replay: base64 payload must not survive into stored context.""" + """Multi-turn replay: base64 payload must not survive into stored context. + + Creates a conversation explicitly and pins the first response to it, + so the GET /v1/conversations/{id}/items round-trip actually exercises + the persistence path. Without an explicit `conversation=...` the + Responses API stores to response chains only — no conversation rows + are written and the assertion would have nothing to inspect. + """ gateway, client, mock_mcp, model = self._ctx(request) + # Create the conversation up front so the response gets linked into + # `conversation_items` (the storage path that + # `compact_image_generation_outputs_json` is wired into). + conv = client.conversations.create() + conversation_id = conv.id + resp1 = client.responses.create( model=model, input=_IMAGE_GEN_PROMPT, @@ -437,23 +435,17 @@ def test_image_generation_compactor_strips_base64(self, request, image_gen_tool_ tool_choice=_FORCED_TOOL_CHOICE, stream=False, store=True, + conversation=conversation_id, ) assert resp1.error is None, f"Turn 1 error: {resp1.error}" # Sanity-check that the tool actually ran. assert _find_image_generation_call(resp1.output) is not None, ( "Turn 1 response did not contain an image_generation_call item; " - "the compactor-replay assertion would be vacuous. " + "the compactor-strip assertion would be vacuous. " f"output types: {[getattr(i, 'type', None) for i in resp1.output or []]}" ) - conversation_id = _extract_conversation_id(resp1) - if not conversation_id: - pytest.skip( - "Gateway did not expose a conversation_id on the first response " - "— compactor-replay assertion depends on stored history." - ) - api_key = client.api_key with httpx.Client(timeout=10.0) as http: items_resp = http.get( @@ -464,19 +456,38 @@ def test_image_generation_compactor_strips_base64(self, request, image_gen_tool_ f"Failed to list conversation items: {items_resp.status_code} {items_resp.text}" ) - # Positive persistence guard. + # Positive persistence guard: confirm an image_generation_call item + # was actually persisted. Otherwise the strip assertion below would + # pass vacuously. items_data = items_resp.json() stored_items = items_data.get("data") or items_data.get("items") or [] - assert stored_items, ( - "Conversation persisted no items — compactor-strip assertion would be " - f"vacuous. Full response: {items_data!r}" + image_items = [item for item in stored_items if item.get("type") == "image_generation_call"] + assert image_items, ( + "No image_generation_call item in stored conversation history; " + "compactor-strip assertion would be vacuous. " + f"stored types: {[item.get('type') for item in stored_items]}" ) - payload = items_resp.text - assert mock_mcp.image_generation_png_base64 not in payload, ( - "Compactor failed to strip base64 payload from stored conversation " - "history; replay would re-ship the image bytes to the model." - ) + # Assert directly on the persisted image_generation_call shape rather + # than the entire payload text. The model's free-form assistant reply + # may legitimately echo the bytes back ("Base64 PNG data: iVBORw…") + # when it sees them in the function_call_output, which is unrelated + # to whether the compactor stripped `result` from the stored + # image_generation_call item itself. + for item in image_items: + content = item.get("content") or {} + assert "result" not in item, ( + f"Compactor failed to strip top-level `result` from stored " + f"image_generation_call: {item!r}" + ) + assert "result" not in content, ( + f"Compactor failed to strip nested `result` from stored " + f"image_generation_call.content: {item!r}" + ) + assert mock_mcp.image_generation_png_base64 not in json.dumps(item), ( + f"Base64 payload survived inside stored image_generation_call " + f"item; replay would re-ship the bytes to the model: {item!r}" + ) # ============================================================================= diff --git a/model_gateway/src/app_context.rs b/model_gateway/src/app_context.rs index f53898e3c..18c9e8b0c 100644 --- a/model_gateway/src/app_context.rs +++ b/model_gateway/src/app_context.rs @@ -23,8 +23,8 @@ use crate::{ observability::inflight_tracker::InFlightRequestTracker, policies::PolicyRegistry, routers::{ - grpc::multimodal::MultimodalConfigRegistry, openai::realtime::RealtimeRegistry, - router_manager::RouterManager, + common::openai_bridge::FormatRegistry, grpc::multimodal::MultimodalConfigRegistry, + openai::realtime::RealtimeRegistry, router_manager::RouterManager, }, wasm::{config::WasmRuntimeConfig, module_manager::WasmModuleManager}, worker::{KvEventMonitor, WorkerMonitor, WorkerRegistry, WorkerService}, @@ -73,6 +73,7 @@ pub struct AppContext { pub worker_job_queue: Arc>>, pub workflow_engines: Arc>, pub mcp_orchestrator: Arc>>, + pub mcp_format_registry: FormatRegistry, pub skill_service: Option>, pub wasm_manager: Option>, pub worker_service: Arc, @@ -112,6 +113,7 @@ pub struct AppContextBuilder { worker_job_queue: Option>>>, workflow_engines: Option>>, mcp_orchestrator: Option>>>, + mcp_format_registry: Option, skill_service: Option>, wasm_manager: Option>, kv_event_monitor: Option>, @@ -167,6 +169,7 @@ impl AppContextBuilder { worker_job_queue: None, workflow_engines: None, mcp_orchestrator: None, + mcp_format_registry: None, skill_service: None, wasm_manager: None, kv_event_monitor: None, @@ -284,6 +287,11 @@ impl AppContextBuilder { self } + pub fn mcp_format_registry(mut self, registry: FormatRegistry) -> Self { + self.mcp_format_registry = Some(registry); + self + } + pub fn wasm_manager(mut self, wasm_manager: Option>) -> Self { self.wasm_manager = wasm_manager; self @@ -396,6 +404,7 @@ impl AppContextBuilder { mcp_orchestrator: self .mcp_orchestrator .ok_or(AppContextBuildError::MissingField("mcp_orchestrator"))?, + mcp_format_registry: self.mcp_format_registry.unwrap_or_default(), skill_service: self.skill_service, wasm_manager: self.wasm_manager, worker_service, @@ -667,6 +676,7 @@ impl AppContextBuilder { .map_err(|_| "Failed to set MCP orchestrator in OnceLock".to_string())?; self.mcp_orchestrator = Some(mcp_orchestrator_lock); + self.mcp_format_registry = Some(FormatRegistry::new()); Ok(self) } diff --git a/model_gateway/src/routers/anthropic/context.rs b/model_gateway/src/routers/anthropic/context.rs index 6efe9bc4d..e3edfcab2 100644 --- a/model_gateway/src/routers/anthropic/context.rs +++ b/model_gateway/src/routers/anthropic/context.rs @@ -11,6 +11,7 @@ use smg_mcp::{McpOrchestrator, McpServerBinding}; use crate::{ middleware::TenantRequestMeta, + routers::common::openai_bridge::FormatRegistry, worker::{Worker, WorkerRegistry}, }; @@ -18,6 +19,7 @@ use crate::{ #[derive(Clone)] pub(crate) struct RouterContext { pub mcp_orchestrator: Arc, + pub mcp_format_registry: FormatRegistry, pub http_client: reqwest::Client, pub worker_registry: Arc, pub request_timeout: Duration, diff --git a/model_gateway/src/routers/anthropic/router.rs b/model_gateway/src/routers/anthropic/router.rs index 2b1cae62d..c75abfca2 100644 --- a/model_gateway/src/routers/anthropic/router.rs +++ b/model_gateway/src/routers/anthropic/router.rs @@ -62,6 +62,7 @@ impl AnthropicRouter { let router_ctx = RouterContext { mcp_orchestrator, + mcp_format_registry: context.mcp_format_registry.clone(), http_client: context.client.clone(), worker_registry: context.worker_registry.clone(), request_timeout, @@ -116,8 +117,13 @@ impl RouterTrait for AnthropicRouter { }) .collect(); - match mcp_utils::ensure_mcp_servers(&self.router_ctx.mcp_orchestrator, &inputs, &[]) - .await + match mcp_utils::ensure_mcp_servers( + &self.router_ctx.mcp_orchestrator, + &self.router_ctx.mcp_format_registry, + &inputs, + &[], + ) + .await { Some(servers) => { info!( diff --git a/model_gateway/src/routers/common/mcp_utils.rs b/model_gateway/src/routers/common/mcp_utils.rs index dd872f51b..d9ba42ced 100644 --- a/model_gateway/src/routers/common/mcp_utils.rs +++ b/model_gateway/src/routers/common/mcp_utils.rs @@ -7,12 +7,13 @@ use std::{ use openai_protocol::responses::{McpAllowedTools, ResponseTool, ResponsesRequest}; use serde_json::{json, Value}; -use smg_mcp::{ - apply_hosted_tool_overrides, extract_hosted_tool_overrides, BuiltinToolType, McpOrchestrator, - McpServerBinding, McpServerConfig, McpTransport, ResponseFormat, -}; +use smg_mcp::{BuiltinToolType, McpOrchestrator, McpServerBinding, McpServerConfig, McpTransport}; use tracing::{debug, warn}; +use crate::routers::common::openai_bridge::{ + apply_hosted_tool_overrides, extract_hosted_tool_overrides, FormatRegistry, ResponseFormat, +}; + /// Default maximum tool loop iterations (safety limit). pub const DEFAULT_MAX_ITERATIONS: usize = 10; @@ -66,6 +67,7 @@ pub struct McpServerInput { /// Returns a list of [`McpServerBinding`]s for successfully connected servers. pub async fn connect_mcp_servers( mcp_orchestrator: &Arc, + format_registry: &FormatRegistry, inputs: &[McpServerInput], ) -> Vec { let mut mcp_servers: Vec = Vec::new(); @@ -117,8 +119,15 @@ pub async fn connect_mcp_servers( let server_key = McpOrchestrator::server_key(&server_config); - match mcp_orchestrator.connect_dynamic_server(server_config).await { + match mcp_orchestrator + .connect_dynamic_server(server_config.clone()) + .await + { Ok(_) => { + // Mirror the orchestrator's tool-config + builtin-format passes + // into the gateway-side FormatRegistry. See + // routers/common/openai_bridge/format_registry.rs. + format_registry.populate_from_server_config(&server_config); if !mcp_servers.iter().any(|b| b.server_key == server_key) { mcp_servers.push(McpServerBinding { label: input.label.clone(), @@ -175,6 +184,7 @@ pub struct BuiltinToolRouting { /// Empty if no built-in tools are found or none have MCP server configurations. pub fn collect_builtin_routing( mcp_orchestrator: &Arc, + format_registry: &FormatRegistry, tools: Option<&[ResponseTool]>, ) -> Vec { let Some(tools) = tools else { @@ -191,9 +201,7 @@ pub fn collect_builtin_routing( _ => continue, }; - if let Some((server_name, tool_name, response_format)) = - mcp_orchestrator.find_builtin_server(builtin_type) - { + if let Some((server_name, tool_name)) = mcp_orchestrator.find_builtin_server(builtin_type) { debug!( builtin_type = ?builtin_type, server = %server_name, @@ -201,6 +209,11 @@ pub fn collect_builtin_routing( "Found MCP server for built-in tool type" ); + // ResponseFormat for the routed (server, tool) pair lives in the + // gateway-side registry — populated at server-registration time + // (see FormatRegistry::populate_from_server_config). + let response_format = format_registry.lookup_by_names(&server_name, &tool_name); + routing.push(BuiltinToolRouting { builtin_type, server_name, @@ -256,14 +269,15 @@ pub(crate) fn collect_user_function_names(request: &ResponsesRequest) -> HashSet /// Returns `Some(servers)` if at least one server is available, `None` otherwise. pub async fn ensure_mcp_servers( orchestrator: &Arc, + format_registry: &FormatRegistry, inputs: &[McpServerInput], builtin_types: &[BuiltinToolType], ) -> Option> { - let mut mcp_servers = connect_mcp_servers(orchestrator, inputs).await; + let mut mcp_servers = connect_mcp_servers(orchestrator, format_registry, inputs).await; // Add builtin tool routing servers for &builtin_type in builtin_types { - if let Some((server_name, tool_name, _)) = orchestrator.find_builtin_server(builtin_type) { + if let Some((server_name, tool_name)) = orchestrator.find_builtin_server(builtin_type) { debug!( builtin_type = ?builtin_type, server = %server_name, @@ -298,6 +312,7 @@ pub async fn ensure_mcp_servers( /// then delegates to [`ensure_mcp_servers`]. pub async fn ensure_request_mcp_client( mcp_orchestrator: &Arc, + format_registry: &FormatRegistry, tools: &[ResponseTool], ) -> Option> { let inputs: Vec = tools @@ -320,7 +335,7 @@ pub async fn ensure_request_mcp_client( let builtin_types = extract_builtin_types(tools); - ensure_mcp_servers(mcp_orchestrator, &inputs, &builtin_types).await + ensure_mcp_servers(mcp_orchestrator, format_registry, &inputs, &builtin_types).await } /// Forward the caller's `user` identifier into hosted-tool dispatch arguments. @@ -343,9 +358,9 @@ pub async fn ensure_request_mcp_client( /// - No-op if `arguments` already contains a `user` key — model-supplied /// values win over the request-level identifier. /// - Otherwise inserts `arguments["user"] = json!(user)`. -pub(crate) fn inject_user_into_hosted_args( +fn inject_user_into_hosted_args( arguments: &mut Value, - response_format: &ResponseFormat, + response_format: ResponseFormat, user: Option<&str>, ) { if response_format.to_builtin_tool_type().is_none() { @@ -394,7 +409,7 @@ pub(crate) fn inject_user_into_hosted_args( /// a single chokepoint. pub(crate) fn prepare_hosted_dispatch_args( arguments: &mut Value, - response_format: &ResponseFormat, + response_format: ResponseFormat, request_tools: &[ResponseTool], request_user: Option<&str>, ) { @@ -422,32 +437,36 @@ mod tests { use super::*; - /// Create a test orchestrator with a built-in server configuration - async fn create_test_orchestrator_with_builtin() -> Arc { + /// Create a test orchestrator with a built-in server configuration, + /// plus a `FormatRegistry` mirroring the same per-tool ResponseFormat + /// values that production code populates at server-registration time. + async fn create_test_orchestrator_with_builtin() -> (Arc, FormatRegistry) { let mut tools_config = HashMap::new(); tools_config.insert( "web_search".to_string(), ToolConfig { - response_format: ResponseFormatConfig::WebSearchCall, + response_format: Some(ResponseFormatConfig::WebSearchCall), ..Default::default() }, ); + let server = McpServerConfig { + name: "search-server".to_string(), + transport: McpTransport::Streamable { + url: "http://localhost:9999/mcp".to_string(), + token: None, + headers: HashMap::new(), + }, + proxy: None, + required: false, + tools: Some(tools_config), + builtin_type: Some(BuiltinToolType::WebSearchPreview), + builtin_tool_name: Some("web_search".to_string()), + internal: false, + }; + let config = McpConfig { - servers: vec![McpServerConfig { - name: "search-server".to_string(), - transport: McpTransport::Streamable { - url: "http://localhost:9999/mcp".to_string(), - token: None, - headers: HashMap::new(), - }, - proxy: None, - required: false, - tools: Some(tools_config), - builtin_type: Some(BuiltinToolType::WebSearchPreview), - builtin_tool_name: Some("web_search".to_string()), - internal: false, - }], + servers: vec![server.clone()], pool: Default::default(), proxy: None, warmup: Vec::new(), @@ -455,12 +474,18 @@ mod tests { policy: Default::default(), }; + let registry = FormatRegistry::new(); + registry.populate_from_server_config(&server); + // Note: This will fail to connect but still create the orchestrator with config - Arc::new(McpOrchestrator::new(config).await.unwrap()) + ( + Arc::new(McpOrchestrator::new(config).await.unwrap()), + registry, + ) } /// Create a test orchestrator without built-in server configuration - async fn create_test_orchestrator_no_builtin() -> Arc { + async fn create_test_orchestrator_no_builtin() -> (Arc, FormatRegistry) { let config = McpConfig { servers: vec![], pool: Default::default(), @@ -470,18 +495,21 @@ mod tests { policy: Default::default(), }; - Arc::new(McpOrchestrator::new(config).await.unwrap()) + ( + Arc::new(McpOrchestrator::new(config).await.unwrap()), + FormatRegistry::new(), + ) } #[tokio::test] async fn test_collect_builtin_routing_with_configured_server() { - let orchestrator = create_test_orchestrator_with_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_with_builtin().await; let tools = vec![ResponseTool::WebSearchPreview( WebSearchPreviewTool::default(), )]; - let routing = collect_builtin_routing(&orchestrator, Some(&tools)); + let routing = collect_builtin_routing(&orchestrator, &format_registry, Some(&tools)); assert_eq!(routing.len(), 1); assert_eq!(routing[0].builtin_type, BuiltinToolType::WebSearchPreview); @@ -492,13 +520,13 @@ mod tests { #[tokio::test] async fn test_collect_builtin_routing_no_configured_server() { - let orchestrator = create_test_orchestrator_no_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_no_builtin().await; let tools = vec![ResponseTool::WebSearchPreview( WebSearchPreviewTool::default(), )]; - let routing = collect_builtin_routing(&orchestrator, Some(&tools)); + let routing = collect_builtin_routing(&orchestrator, &format_registry, Some(&tools)); // No routing because no server configured for this built-in type assert!(routing.is_empty()); @@ -506,7 +534,7 @@ mod tests { #[tokio::test] async fn test_collect_builtin_routing_ignores_mcp_tools() { - let orchestrator = create_test_orchestrator_with_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_with_builtin().await; let tools = vec![ResponseTool::Mcp(McpTool { server_url: Some("http://example.com/mcp".to_string()), @@ -520,7 +548,7 @@ mod tests { defer_loading: None, })]; - let routing = collect_builtin_routing(&orchestrator, Some(&tools)); + let routing = collect_builtin_routing(&orchestrator, &format_registry, Some(&tools)); // MCP tools are not built-in types, should be empty assert!(routing.is_empty()); @@ -528,7 +556,7 @@ mod tests { #[tokio::test] async fn test_collect_builtin_routing_ignores_function_tools() { - let orchestrator = create_test_orchestrator_with_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_with_builtin().await; let tools = vec![ResponseTool::Function(FunctionTool { function: Function { @@ -539,7 +567,7 @@ mod tests { }, })]; - let routing = collect_builtin_routing(&orchestrator, Some(&tools)); + let routing = collect_builtin_routing(&orchestrator, &format_registry, Some(&tools)); // Function tools are not built-in types, should be empty assert!(routing.is_empty()); @@ -547,9 +575,9 @@ mod tests { #[tokio::test] async fn test_collect_builtin_routing_none_tools() { - let orchestrator = create_test_orchestrator_no_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_no_builtin().await; - let routing = collect_builtin_routing(&orchestrator, None); + let routing = collect_builtin_routing(&orchestrator, &format_registry, None); assert!(routing.is_empty()); } @@ -561,7 +589,7 @@ mod tests { web_search_tools.insert( "web_search".to_string(), ToolConfig { - response_format: ResponseFormatConfig::WebSearchCall, + response_format: Some(ResponseFormatConfig::WebSearchCall), ..Default::default() }, ); @@ -570,7 +598,7 @@ mod tests { code_interp_tools.insert( "run_code".to_string(), ToolConfig { - response_format: ResponseFormatConfig::CodeInterpreterCall, + response_format: Some(ResponseFormatConfig::CodeInterpreterCall), ..Default::default() }, ); @@ -613,6 +641,10 @@ mod tests { policy: Default::default(), }; + let format_registry = FormatRegistry::new(); + for server in &config.servers { + format_registry.populate_from_server_config(server); + } let orchestrator = Arc::new(McpOrchestrator::new(config).await.unwrap()); let tools = vec![ @@ -620,7 +652,7 @@ mod tests { ResponseTool::CodeInterpreter(CodeInterpreterTool::default()), ]; - let routing = collect_builtin_routing(&orchestrator, Some(&tools)); + let routing = collect_builtin_routing(&orchestrator, &format_registry, Some(&tools)); assert_eq!(routing.len(), 2); @@ -656,7 +688,7 @@ mod tests { image_gen_tools.insert( "generate_image".to_string(), ToolConfig { - response_format: ResponseFormatConfig::ImageGenerationCall, + response_format: Some(ResponseFormatConfig::ImageGenerationCall), ..Default::default() }, ); @@ -683,11 +715,15 @@ mod tests { policy: Default::default(), }; + let format_registry = FormatRegistry::new(); + for server in &config.servers { + format_registry.populate_from_server_config(server); + } let orchestrator = Arc::new(McpOrchestrator::new(config).await.unwrap()); let tools = vec![ResponseTool::ImageGeneration(ImageGenerationTool::default())]; - let routing = collect_builtin_routing(&orchestrator, Some(&tools)); + let routing = collect_builtin_routing(&orchestrator, &format_registry, Some(&tools)); assert_eq!(routing.len(), 1); assert_eq!(routing[0].builtin_type, BuiltinToolType::ImageGeneration); @@ -706,14 +742,14 @@ mod tests { #[tokio::test] async fn test_ensure_request_mcp_client_with_builtin_routing() { // Create orchestrator with a built-in server configured - let orchestrator = create_test_orchestrator_with_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_with_builtin().await; // Request has web_search_preview tool (no server_url, not MCP type) let tools = vec![ResponseTool::WebSearchPreview( WebSearchPreviewTool::default(), )]; - let result = ensure_request_mcp_client(&orchestrator, &tools).await; + let result = ensure_request_mcp_client(&orchestrator, &format_registry, &tools).await; // Should return Some because built-in routing is configured assert!(result.is_some()); @@ -729,14 +765,14 @@ mod tests { #[tokio::test] async fn test_ensure_request_mcp_client_no_builtin_routing() { // Create orchestrator WITHOUT built-in server configured - let orchestrator = create_test_orchestrator_no_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_no_builtin().await; // Request has web_search_preview tool let tools = vec![ResponseTool::WebSearchPreview( WebSearchPreviewTool::default(), )]; - let result = ensure_request_mcp_client(&orchestrator, &tools).await; + let result = ensure_request_mcp_client(&orchestrator, &format_registry, &tools).await; // Should return None because no MCP or built-in routing is available assert!(result.is_none()); @@ -744,7 +780,7 @@ mod tests { #[tokio::test] async fn test_ensure_request_mcp_client_function_tools_only() { - let orchestrator = create_test_orchestrator_with_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_with_builtin().await; // Request has only function tools (no MCP, no built-in) let tools = vec![ResponseTool::Function(FunctionTool { @@ -756,7 +792,7 @@ mod tests { }, })]; - let result = ensure_request_mcp_client(&orchestrator, &tools).await; + let result = ensure_request_mcp_client(&orchestrator, &format_registry, &tools).await; // Should return None - function tools don't need MCP processing assert!(result.is_none()); @@ -765,7 +801,7 @@ mod tests { #[tokio::test] async fn test_ensure_request_mcp_client_mixed_tools() { // Create orchestrator with built-in server - let orchestrator = create_test_orchestrator_with_builtin().await; + let (orchestrator, format_registry) = create_test_orchestrator_with_builtin().await; // Request has mixed tools: function + web_search_preview let tools = vec![ @@ -780,7 +816,7 @@ mod tests { ResponseTool::WebSearchPreview(WebSearchPreviewTool::default()), ]; - let result = ensure_request_mcp_client(&orchestrator, &tools).await; + let result = ensure_request_mcp_client(&orchestrator, &format_registry, &tools).await; // Should return Some because web_search_preview has built-in routing assert!(result.is_some()); @@ -866,7 +902,7 @@ mod tests { let mut args = json!({"prompt": "a cat", "user": null}); inject_user_into_hosted_args( &mut args, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, Some("user-123"), ); assert_eq!(args.get("user"), Some(&json!("user-123"))); @@ -877,7 +913,7 @@ mod tests { let mut args = json!({"prompt": "a cat"}); inject_user_into_hosted_args( &mut args, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, Some("user-123"), ); assert_eq!(args.get("user"), Some(&json!("user-123"))); @@ -891,7 +927,7 @@ mod tests { let mut args = json!({"prompt": "a cat", "user": "model-supplied"}); inject_user_into_hosted_args( &mut args, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, Some("request-level"), ); assert_eq!(args.get("user"), Some(&json!("model-supplied"))); @@ -902,7 +938,7 @@ mod tests { // Plain MCP function tools have caller-defined schemas; injecting // `user` could surprise tools that don't expect that key. let mut args = json!({"q": "weather"}); - inject_user_into_hosted_args(&mut args, &ResponseFormat::Passthrough, Some("user-123")); + inject_user_into_hosted_args(&mut args, ResponseFormat::Passthrough, Some("user-123")); assert!( !args.as_object().unwrap().contains_key("user"), "passthrough format must not receive an injected user key" @@ -912,18 +948,18 @@ mod tests { #[test] fn inject_user_empty_or_missing_user_is_noop() { let mut args_none = json!({"prompt": "x"}); - inject_user_into_hosted_args(&mut args_none, &ResponseFormat::WebSearchCall, None); + inject_user_into_hosted_args(&mut args_none, ResponseFormat::WebSearchCall, None); assert!(!args_none.as_object().unwrap().contains_key("user")); let mut args_empty = json!({"prompt": "x"}); - inject_user_into_hosted_args(&mut args_empty, &ResponseFormat::WebSearchCall, Some("")); + inject_user_into_hosted_args(&mut args_empty, ResponseFormat::WebSearchCall, Some("")); assert!(!args_empty.as_object().unwrap().contains_key("user")); } #[test] fn inject_user_non_object_args_is_noop() { let mut args = json!("not-an-object"); - inject_user_into_hosted_args(&mut args, &ResponseFormat::FileSearchCall, Some("u1")); + inject_user_into_hosted_args(&mut args, ResponseFormat::FileSearchCall, Some("u1")); assert_eq!(args, json!("not-an-object")); } @@ -937,7 +973,7 @@ mod tests { ResponseFormat::FileSearchCall, ] { let mut args = json!({}); - inject_user_into_hosted_args(&mut args, &format, Some("user-xyz")); + inject_user_into_hosted_args(&mut args, format, Some("user-xyz")); assert_eq!( args.get("user"), Some(&json!("user-xyz")), diff --git a/model_gateway/src/routers/common/mod.rs b/model_gateway/src/routers/common/mod.rs index ef17891ff..c1068bf10 100644 --- a/model_gateway/src/routers/common/mod.rs +++ b/model_gateway/src/routers/common/mod.rs @@ -22,6 +22,7 @@ pub mod background; pub mod header_utils; pub mod mcp_utils; +pub mod openai_bridge; pub mod persistence_utils; pub mod retry; pub mod worker_selection; diff --git a/model_gateway/src/routers/common/openai_bridge/format_descriptor.rs b/model_gateway/src/routers/common/openai_bridge/format_descriptor.rs new file mode 100644 index 000000000..e073a69d9 --- /dev/null +++ b/model_gateway/src/routers/common/openai_bridge/format_descriptor.rs @@ -0,0 +1,75 @@ +//! Per-format dispatch table — one row per `ResponseFormat`. +//! +//! Routers read fields off [`FormatDescriptor`] instead of matching on the +//! enum directly, so adding a new format is one row here and zero edits in +//! router code. + +use openai_protocol::event_types::{ + CodeInterpreterCallEvent, FileSearchCallEvent, ImageGenerationCallEvent, ItemType, McpEvent, + WebSearchCallEvent, +}; + +use super::ResponseFormat; + +#[derive(Debug, Clone, Copy)] +pub struct FormatDescriptor { + pub type_str: &'static str, + pub id_prefix: &'static str, + pub in_progress_event: &'static str, + /// `None` for formats with no intermediate phase (e.g. Passthrough). + pub searching_event: Option<&'static str>, + pub completed_event: &'static str, + pub streams_arguments: bool, + /// `None` for formats without a partial-image-style intermediate frame. + pub partial_image_event: Option<&'static str>, +} + +pub const fn descriptor(format: ResponseFormat) -> FormatDescriptor { + match format { + ResponseFormat::WebSearchCall => FormatDescriptor { + type_str: ItemType::WEB_SEARCH_CALL, + id_prefix: "ws_", + in_progress_event: WebSearchCallEvent::IN_PROGRESS, + searching_event: Some(WebSearchCallEvent::SEARCHING), + completed_event: WebSearchCallEvent::COMPLETED, + streams_arguments: false, + partial_image_event: None, + }, + ResponseFormat::CodeInterpreterCall => FormatDescriptor { + type_str: ItemType::CODE_INTERPRETER_CALL, + id_prefix: "ci_", + in_progress_event: CodeInterpreterCallEvent::IN_PROGRESS, + searching_event: Some(CodeInterpreterCallEvent::INTERPRETING), + completed_event: CodeInterpreterCallEvent::COMPLETED, + streams_arguments: false, + partial_image_event: None, + }, + ResponseFormat::FileSearchCall => FormatDescriptor { + type_str: ItemType::FILE_SEARCH_CALL, + id_prefix: "fs_", + in_progress_event: FileSearchCallEvent::IN_PROGRESS, + searching_event: Some(FileSearchCallEvent::SEARCHING), + completed_event: FileSearchCallEvent::COMPLETED, + streams_arguments: false, + partial_image_event: None, + }, + ResponseFormat::ImageGenerationCall => FormatDescriptor { + type_str: ItemType::IMAGE_GENERATION_CALL, + id_prefix: "ig_", + in_progress_event: ImageGenerationCallEvent::IN_PROGRESS, + searching_event: Some(ImageGenerationCallEvent::GENERATING), + completed_event: ImageGenerationCallEvent::COMPLETED, + streams_arguments: false, + partial_image_event: Some(ImageGenerationCallEvent::PARTIAL_IMAGE), + }, + ResponseFormat::Passthrough => FormatDescriptor { + type_str: ItemType::MCP_CALL, + id_prefix: "mcp_", + in_progress_event: McpEvent::CALL_IN_PROGRESS, + searching_event: None, + completed_event: McpEvent::CALL_COMPLETED, + streams_arguments: true, + partial_image_event: None, + }, + } +} diff --git a/model_gateway/src/routers/common/openai_bridge/format_registry.rs b/model_gateway/src/routers/common/openai_bridge/format_registry.rs new file mode 100644 index 000000000..7170e15fa --- /dev/null +++ b/model_gateway/src/routers/common/openai_bridge/format_registry.rs @@ -0,0 +1,271 @@ +//! Side-map of `QualifiedToolName → ResponseFormat`, populated when MCP +//! servers register and queried by router code at request time. + +use std::sync::Arc; + +use dashmap::DashMap; +use smg_mcp::{inventory::ALIAS_SERVER_KEY, McpServerConfig, QualifiedToolName}; + +use super::ResponseFormat; + +/// Resolve an exposed tool name's `ResponseFormat` via the session's name map +/// and the registry. Returns `Passthrough` for unknown tools. +/// +/// Lives next to `FormatRegistry` because it's a thin lookup helper that +/// composes the session's name map with `FormatRegistry::lookup`. Reuses the +/// `QualifiedToolName` returned by `qualified_name_for_exposed` rather than +/// rebuilding one, so we pay the two `Arc` allocations once instead of +/// twice per call. +pub fn lookup_tool_format( + session: &smg_mcp::McpToolSession<'_>, + registry: &FormatRegistry, + exposed_name: &str, +) -> ResponseFormat { + let Some(qn) = session.qualified_name_for_exposed(exposed_name) else { + return ResponseFormat::Passthrough; + }; + registry.lookup(&qn) +} + +#[derive(Default, Debug, Clone)] +pub struct FormatRegistry { + formats: Arc>, +} + +impl FormatRegistry { + pub fn new() -> Self { + Self::default() + } + + pub fn lookup(&self, qualified: &QualifiedToolName) -> ResponseFormat { + self.formats + .get(qualified) + .map(|r| *r.value()) + .unwrap_or(ResponseFormat::Passthrough) + } + + pub fn lookup_by_names(&self, server_key: &str, tool_name: &str) -> ResponseFormat { + self.lookup(&QualifiedToolName::new(server_key, tool_name)) + } + + fn insert(&self, qualified: QualifiedToolName, format: ResponseFormat) { + self.formats.insert(qualified, format); + } + + /// Populate from a server config: per-tool overrides + builtin defaults. + /// Safe to call repeatedly — entries for non-Passthrough formats are + /// overwritten. Downgrading a format back to `Passthrough` requires a + /// separate registry rebuild (no production caller mutates configs in + /// place today). + /// + /// Mirrors `McpOrchestrator::apply_tool_configs`: + /// - When a tool has an `alias`, the format is attached **only** to the + /// alias entry (under `("alias", alias_name)`), matching the orchestrator's + /// `register_alias` qualified-name shape. The underlying `(server, tool)` + /// stays at the `Passthrough` default so direct calls aren't transformed. + /// - When a tool has no alias but a non-default format, attach to + /// `(server, tool)` directly. + /// - When the per-tool stanza omits `response_format` entirely + /// (`None`), the builtin default still applies. This lets users add an + /// `alias` or `arg_mapping` to a builtin tool without disabling its + /// hosted-format wire shape. An explicit `Some(Passthrough)` *does* + /// block the builtin default — that is the documented escape hatch + /// for opting out of the hosted shape. + pub fn populate_from_server_config(&self, config: &McpServerConfig) { + if let Some(tools) = &config.tools { + for (tool_name, tool_config) in tools { + let Some(format_config) = tool_config.response_format else { + continue; + }; + let format: ResponseFormat = format_config.into(); + if format == ResponseFormat::Passthrough { + continue; + } + if let Some(alias) = &tool_config.alias { + self.insert(QualifiedToolName::new(ALIAS_SERVER_KEY, alias), format); + } else { + self.insert(QualifiedToolName::new(&config.name, tool_name), format); + } + } + } + + if let (Some(builtin_type), Some(tool_name)) = + (&config.builtin_type, &config.builtin_tool_name) + { + let has_explicit_format = config + .tools + .as_ref() + .and_then(|tools| tools.get(tool_name)) + .is_some_and(|cfg| cfg.response_format.is_some()); + if !has_explicit_format { + let format: ResponseFormat = builtin_type.response_format().into(); + self.insert(QualifiedToolName::new(&config.name, tool_name), format); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use smg_mcp::{ + BuiltinToolType, McpServerConfig, McpTransport, ResponseFormatConfig, ToolConfig, + }; + + use super::*; + + fn server(name: &str) -> McpServerConfig { + McpServerConfig { + name: name.to_string(), + transport: McpTransport::Streamable { + url: "http://x".to_string(), + token: None, + headers: HashMap::new(), + }, + proxy: None, + required: false, + tools: None, + builtin_type: None, + builtin_tool_name: None, + internal: false, + } + } + + #[test] + fn lookup_unknown_returns_passthrough() { + let r = FormatRegistry::new(); + assert_eq!( + r.lookup_by_names("any", "tool"), + ResponseFormat::Passthrough + ); + } + + #[test] + fn alias_format_stored_under_alias_server_key() { + // Mirrors orchestrator::register_alias which uses + // QualifiedToolName::new("alias", alias_name). + let mut tools = HashMap::new(); + tools.insert( + "brave_web_search".to_string(), + ToolConfig { + alias: Some("web_search".to_string()), + response_format: Some(ResponseFormatConfig::WebSearchCall), + arg_mapping: None, + }, + ); + let mut cfg = server("brave"); + cfg.tools = Some(tools); + + let r = FormatRegistry::new(); + r.populate_from_server_config(&cfg); + + assert_eq!( + r.lookup_by_names("alias", "web_search"), + ResponseFormat::WebSearchCall, + "alias entry must use the literal `alias` server_key prefix" + ); + assert_eq!( + r.lookup_by_names("brave", "brave_web_search"), + ResponseFormat::Passthrough, + "underlying tool entry must NOT receive the format when an alias exists" + ); + } + + #[test] + fn non_aliased_tool_stores_format_under_server_tool_pair() { + let mut tools = HashMap::new(); + tools.insert( + "search".to_string(), + ToolConfig { + alias: None, + response_format: Some(ResponseFormatConfig::WebSearchCall), + arg_mapping: None, + }, + ); + let mut cfg = server("brave"); + cfg.tools = Some(tools); + + let r = FormatRegistry::new(); + r.populate_from_server_config(&cfg); + + assert_eq!( + r.lookup_by_names("brave", "search"), + ResponseFormat::WebSearchCall + ); + } + + #[test] + fn builtin_default_applies_when_no_explicit_tool_config() { + let mut cfg = server("search"); + cfg.builtin_type = Some(BuiltinToolType::WebSearchPreview); + cfg.builtin_tool_name = Some("do_search".to_string()); + + let r = FormatRegistry::new(); + r.populate_from_server_config(&cfg); + + assert_eq!( + r.lookup_by_names("search", "do_search"), + ResponseFormat::WebSearchCall + ); + } + + #[test] + fn explicit_per_tool_override_wins_over_builtin_default() { + let mut tools = HashMap::new(); + tools.insert( + "do_search".to_string(), + ToolConfig { + alias: None, + // Explicit override differs from the builtin default. + response_format: Some(ResponseFormatConfig::Passthrough), + arg_mapping: None, + }, + ); + let mut cfg = server("search"); + cfg.tools = Some(tools); + cfg.builtin_type = Some(BuiltinToolType::WebSearchPreview); + cfg.builtin_tool_name = Some("do_search".to_string()); + + let r = FormatRegistry::new(); + r.populate_from_server_config(&cfg); + + // Explicit Some(Passthrough) override means "no entry inserted" AND + // the builtin default is NOT applied on top. + assert_eq!( + r.lookup_by_names("search", "do_search"), + ResponseFormat::Passthrough + ); + } + + #[test] + fn alias_only_stanza_preserves_builtin_default() { + // Regression: a per-tool stanza that only aliases a builtin tool + // (or only sets arg_mapping) used to suppress the builtin default, + // collapsing the hosted format to plain mcp_call. With + // `response_format: None` meaning "inherit context", the builtin + // default must still apply. + let mut tools = HashMap::new(); + tools.insert( + "do_search".to_string(), + ToolConfig { + alias: Some("web_search".to_string()), + response_format: None, + arg_mapping: None, + }, + ); + let mut cfg = server("search"); + cfg.tools = Some(tools); + cfg.builtin_type = Some(BuiltinToolType::WebSearchPreview); + cfg.builtin_tool_name = Some("do_search".to_string()); + + let r = FormatRegistry::new(); + r.populate_from_server_config(&cfg); + + assert_eq!( + r.lookup_by_names("search", "do_search"), + ResponseFormat::WebSearchCall, + "alias-only stanza must not disable the builtin's hosted format" + ); + } +} diff --git a/model_gateway/src/routers/common/openai_bridge/mod.rs b/model_gateway/src/routers/common/openai_bridge/mod.rs new file mode 100644 index 000000000..7ec13c1b2 --- /dev/null +++ b/model_gateway/src/routers/common/openai_bridge/mod.rs @@ -0,0 +1,26 @@ +//! OpenAI Responses API bridge. +//! +//! Conversion logic between MCP protocol types and OpenAI Responses API +//! shapes. Gateway-internal — `smg-mcp` does not depend on OpenAI vocabulary. + +pub mod format_descriptor; +pub mod format_registry; +pub mod overrides; +pub mod response_format; +pub mod tool_descriptors; +pub mod transformer; + +pub use format_descriptor::{descriptor, FormatDescriptor}; +pub use format_registry::{lookup_tool_format, FormatRegistry}; +pub use overrides::{apply_hosted_tool_overrides, extract_hosted_tool_overrides}; +pub use response_format::ResponseFormat; +pub use tool_descriptors::{ + build_mcp_tool_infos, chat_function_tools, configure_response_tools_approval, + function_tools_json, inject_client_visible_mcp_output_items, mcp_list_tools_item, + mcp_list_tools_json, response_tools, should_hide_output_item_json, should_hide_tool_json, +}; +pub use transformer::{ + compact_image_generation_output, compact_image_generation_outputs_json, + extract_embedded_openai_responses, mcp_response_item_id, transform_tool_output, + ResponseTransformer, +}; diff --git a/crates/mcp/src/transform/overrides.rs b/model_gateway/src/routers/common/openai_bridge/overrides.rs similarity index 98% rename from crates/mcp/src/transform/overrides.rs rename to model_gateway/src/routers/common/openai_bridge/overrides.rs index 030f0fce6..d56c79d33 100644 --- a/crates/mcp/src/transform/overrides.rs +++ b/model_gateway/src/routers/common/openai_bridge/overrides.rs @@ -18,7 +18,7 @@ //! //! These helpers are intentionally: //! - **pure**: no session state, no I/O; testable in isolation. -//! - **symmetric** with [`crate::transform::transformer`], which handles the +//! - **symmetric** with [`super::transformer`], which handles the //! inverse direction (MCP result → `ResponseOutputItem`). //! //! # Contract @@ -35,8 +35,7 @@ use openai_protocol::responses::ResponseTool; use serde_json::Value; - -use crate::core::config::BuiltinToolType; +use smg_mcp::BuiltinToolType; /// Extract caller-declared configuration for the given hosted-tool kind from /// the request's `tools` declarations. diff --git a/crates/mcp/src/transform/types.rs b/model_gateway/src/routers/common/openai_bridge/response_format.rs similarity index 70% rename from crates/mcp/src/transform/types.rs rename to model_gateway/src/routers/common/openai_bridge/response_format.rs index f6379410e..80dd6422e 100644 --- a/crates/mcp/src/transform/types.rs +++ b/model_gateway/src/routers/common/openai_bridge/response_format.rs @@ -1,35 +1,28 @@ -//! Response transformation types. +//! `ResponseFormat` enum. use serde::{Deserialize, Serialize}; +use smg_mcp::{BuiltinToolType, ResponseFormatConfig}; -use crate::core::config::{BuiltinToolType, ResponseFormatConfig}; - -/// Format for transforming MCP responses to API-specific formats. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +/// Format for transforming MCP responses to OpenAI-shaped output items. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "snake_case")] pub enum ResponseFormat { - /// Pass through MCP result unchanged as mcp_call output + /// Pass through MCP result unchanged as `mcp_call` output. #[default] Passthrough, - /// Transform to OpenAI web_search_call format + /// Transform to OpenAI `web_search_call` format. WebSearchCall, - /// Transform to OpenAI code_interpreter_call format + /// Transform to OpenAI `code_interpreter_call` format. CodeInterpreterCall, - /// Transform to OpenAI file_search_call format + /// Transform to OpenAI `file_search_call` format. FileSearchCall, - /// Transform to OpenAI image_generation_call format + /// Transform to OpenAI `image_generation_call` format. ImageGenerationCall, } impl ResponseFormat { - /// Inverse of [`BuiltinToolType::response_format`]: returns the hosted-tool - /// kind this response format corresponds to, or `None` for the non-hosted - /// `Passthrough` format. - /// - /// Router dispatch paths use this to look up caller-declared overrides - /// for the current tool's hosted-tool kind without threading - /// [`BuiltinToolType`] separately through session bindings. - pub fn to_builtin_tool_type(&self) -> Option { + /// Inverse of [`smg_mcp::BuiltinToolType::response_format`]; `None` for `Passthrough`. + pub const fn to_builtin_tool_type(self) -> Option { match self { ResponseFormat::Passthrough => None, ResponseFormat::WebSearchCall => Some(BuiltinToolType::WebSearchPreview), @@ -71,11 +64,9 @@ mod tests { "\"image_generation_call\"", ), ]; - for (format, expected) in formats { let serialized = serde_json::to_string(&format).unwrap(); assert_eq!(serialized, expected); - let deserialized: ResponseFormat = serde_json::from_str(&serialized).unwrap(); assert_eq!(deserialized, format); } @@ -88,8 +79,6 @@ mod tests { #[test] fn test_to_builtin_tool_type_round_trip() { - // Non-passthrough formats should round-trip cleanly through - // BuiltinToolType::response_format. let kinds = [ BuiltinToolType::WebSearchPreview, BuiltinToolType::CodeInterpreter, @@ -100,8 +89,6 @@ mod tests { let fmt: ResponseFormat = kind.response_format().into(); assert_eq!(fmt.to_builtin_tool_type(), Some(kind)); } - - // Passthrough is not a hosted-tool kind. assert_eq!(ResponseFormat::Passthrough.to_builtin_tool_type(), None); } } diff --git a/model_gateway/src/routers/common/openai_bridge/tool_descriptors.rs b/model_gateway/src/routers/common/openai_bridge/tool_descriptors.rs new file mode 100644 index 000000000..7d513bb87 --- /dev/null +++ b/model_gateway/src/routers/common/openai_bridge/tool_descriptors.rs @@ -0,0 +1,313 @@ +//! Builders for Responses/Chat tool payloads derived from MCP tool inventory. +//! +//! `ToolEntry` stores its JSON Schema as `Arc`. Downstream +//! protocol types require an owned `serde_json::Map` inside `Value::Object`, +//! so every builder deep-clones the schema once per tool per call. Schema +//! maps are typically small; if this ever profiles hot, cache the +//! materialised `Value::Object` alongside the `Arc` on `ToolEntry`. + +use openai_protocol::{ + common::{Function, Tool}, + responses::{ + generate_id, FunctionTool, McpAllowedTools, McpToolInfo, RequireApproval, + RequireApprovalMode, ResponseOutputItem, ResponseTool, + }, +}; +use serde_json::{json, Value}; +use smg_mcp::{ApprovalMode, McpToolSession}; + +#[inline] +fn schema_to_value(schema: &serde_json::Map) -> Value { + Value::Object(schema.clone()) +} + +/// `(exposed_name, description, &schema)` triples for every tool exposed by +/// `session`. Centralizes name-resolution (alias / disambiguation) so the +/// per-protocol builders below stay one-liners. +fn exposed_tool_fields<'a>( + session: &'a McpToolSession<'a>, +) -> impl Iterator, &'a serde_json::Map)> + 'a { + let exposed = session.exposed_name_by_qualified(); + session.mcp_tools().iter().map(move |entry| { + let name = exposed + .get(&entry.qualified_name) + .map(String::as_str) + .unwrap_or_else(|| entry.tool_name()); + let description = entry.tool.description.as_deref(); + (name, description, &*entry.tool.input_schema) + }) +} + +/// Function-tool JSON payloads (`{"type": "function", ...}`) for upstream model calls. +pub fn function_tools_json(session: &McpToolSession<'_>) -> Vec { + exposed_tool_fields(session) + .map(|(name, description, parameters)| { + json!({ + "type": "function", + "name": name, + "description": description, + "parameters": schema_to_value(parameters) + }) + }) + .collect() +} + +/// Chat API function tools. +pub fn chat_function_tools(session: &McpToolSession<'_>) -> Vec { + exposed_tool_fields(session) + .map(|(name, description, parameters)| Tool { + tool_type: "function".to_string(), + function: Function { + name: name.to_string(), + description: description.map(str::to_string), + parameters: schema_to_value(parameters), + strict: None, + }, + }) + .collect() +} + +/// Responses API function tools. MCP tools surface to the model as +/// `{"type": "function", ...}` entries. +pub fn response_tools(session: &McpToolSession<'_>) -> Vec { + exposed_tool_fields(session) + .map(|(name, description, parameters)| { + ResponseTool::Function(FunctionTool { + function: Function { + name: name.to_string(), + description: description.map(str::to_string), + parameters: schema_to_value(parameters), + strict: None, + }, + }) + }) + .collect() +} + +/// `McpToolInfo` records used inside `mcp_list_tools` output items. +pub fn build_mcp_tool_infos(entries: &[smg_mcp::ToolEntry]) -> Vec { + entries + .iter() + .map(|entry| McpToolInfo { + name: entry.tool_name().to_string(), + description: entry.tool.description.as_ref().map(|d| d.to_string()), + input_schema: schema_to_value(&entry.tool.input_schema), + annotations: entry + .tool + .annotations + .as_ref() + .and_then(|a| serde_json::to_value(a).ok()), + }) + .collect() +} + +/// Typed `mcp_list_tools` output item for one server's exposed tools. +pub fn mcp_list_tools_item( + session: &McpToolSession<'_>, + server_label: &str, + server_key: &str, +) -> ResponseOutputItem { + let tools = session.list_tools_for_server(server_key); + ResponseOutputItem::McpListTools { + id: generate_id("mcpl"), + server_label: server_label.to_string(), + tools: build_mcp_tool_infos(&tools), + error: None, + } +} + +/// JSON form of `mcp_list_tools_item`. Falls back to a minimal stub if the +/// typed item fails to serialise (should be infallible for well-formed input). +pub fn mcp_list_tools_json( + session: &McpToolSession<'_>, + server_label: &str, + server_key: &str, +) -> Value { + serde_json::to_value(mcp_list_tools_item(session, server_label, server_key)).unwrap_or_else( + |_| json!({ "type": "mcp_list_tools", "server_label": server_label, "tools": [] }), + ) +} + +/// Inject only client-visible MCP metadata and call items into a response output array. +/// +/// Visibility policy: +/// - Hide builtin `mcp_list_tools` (builtin tools surface under their own type). +/// - Hide internal non-builtin `mcp_list_tools`. +/// - Hide internal non-builtin passthrough `mcp_call`/`mcp_approval_request`. +/// - Keep builtin-routed call items visible. +/// - Keep user-defined function calls visible even on name collisions. +pub fn inject_client_visible_mcp_output_items( + session: &McpToolSession<'_>, + output: &mut Vec, + tool_call_items: Vec, + user_function_names: &std::collections::HashSet, +) { + let existing = std::mem::take(output); + let servers = session.mcp_servers(); + output.reserve(servers.len() + tool_call_items.len() + existing.len()); + + for binding in servers { + if !session.is_internal_non_builtin_server_label(&binding.label) { + output.push(mcp_list_tools_item( + session, + &binding.label, + &binding.server_key, + )); + } + } + + for item in tool_call_items { + if is_client_visible_output_item(session, &item, user_function_names) { + output.push(item); + } + } + + for item in existing { + if is_client_visible_output_item(session, &item, user_function_names) { + output.push(item); + } + } +} + +/// Apply request-time approval configuration to exposed tools in `session`. +/// +/// Parses `ResponseTool::Mcp::require_approval`/`allowed_tools` and forwards +/// the resolved approval mode + scoping to `session.set_approval_mode`. +/// +/// `McpAllowedTools` projection (T11): +/// - `None` or `Filter { None, None }` → no name constraint (every binding +/// on the server inherits the explicit approval mode). +/// - `List(names)` / `Filter { tool_names: Some(v), .. }` → constrain by +/// explicit names. +/// - `Filter { tool_names: None, read_only: Some(_) }` → `None`. The +/// `readOnlyHint`-based filter is unimplemented; the safe direction for +/// approval scoping is "over-gate" — return `None` so the requested mode +/// applies to every binding. +pub fn configure_response_tools_approval(session: &mut McpToolSession<'_>, tools: &[ResponseTool]) { + for tool in tools { + let ResponseTool::Mcp(mcp_tool) = tool else { + continue; + }; + + let approval_mode = match mcp_tool.require_approval.as_ref() { + Some(RequireApproval::Mode(RequireApprovalMode::Always)) => ApprovalMode::Interactive, + _ => ApprovalMode::PolicyOnly, + }; + + if approval_mode == ApprovalMode::PolicyOnly { + continue; + } + + let allowed_tool_names: Option<&[String]> = + mcp_tool.allowed_tools.as_ref().and_then(|at| match at { + McpAllowedTools::List(names) => Some(names.as_slice()), + McpAllowedTools::Filter(filter) => filter.tool_names.as_deref(), + }); + session.set_approval_mode(&mcp_tool.server_label, allowed_tool_names, approval_mode); + } +} + +/// True when a JSON tool entry should be hidden from client-facing responses. +/// Used by OpenAI non-streaming response normalization (tools handled as +/// `serde_json::Value` payloads). +pub fn should_hide_tool_json( + session: &McpToolSession<'_>, + tool: &Value, + user_function_names: &std::collections::HashSet, +) -> bool { + match tool.get("type").and_then(|v| v.as_str()) { + Some("function") => function_tool_name_json(tool) + .is_some_and(|name| session.should_hide_function_call_like(name, user_function_names)), + // MCP tool entries are keyed by server metadata; function-name + // collision handling does not apply. + Some("mcp") => tool + .get("server_label") + .and_then(|v| v.as_str()) + .is_some_and(|label| session.is_internal_non_builtin_server_label(label)), + _ => false, + } +} + +/// True when a JSON output item should be hidden from client-facing responses. +/// Mirrors `is_client_visible_output_item` for the non-streaming path that +/// operates on raw JSON instead of typed `ResponseOutputItem`s. +pub fn should_hide_output_item_json( + session: &McpToolSession<'_>, + item: &Value, + user_function_names: &std::collections::HashSet, +) -> bool { + match item.get("type").and_then(|v| v.as_str()) { + Some("mcp_list_tools") => item + .get("server_label") + .and_then(|v| v.as_str()) + .is_some_and(|label| { + session.is_builtin_server_label(label) + || session.is_internal_non_builtin_server_label(label) + }), + Some("mcp_call") | Some("mcp_approval_request") => { + let matches_internal = item + .get("server_label") + .and_then(|v| v.as_str()) + .is_some_and(|label| session.is_internal_non_builtin_server_label(label)); + match item.get("name").and_then(|v| v.as_str()) { + Some(name) => { + session.should_hide_mcp_call_like_by_server_flag(name, matches_internal) + } + None => matches_internal, + } + } + Some("function_call") | Some("function_tool_call") => item + .get("name") + .and_then(|v| v.as_str()) + .is_some_and(|name| session.should_hide_function_call_like(name, user_function_names)), + _ => false, + } +} + +fn function_tool_name_json(tool: &Value) -> Option<&str> { + if tool.get("type").and_then(|v| v.as_str()) != Some("function") { + return None; + } + tool.get("name").and_then(|v| v.as_str()).or_else(|| { + tool.get("function") + .and_then(|f| f.get("name")) + .and_then(|v| v.as_str()) + }) +} + +fn is_client_visible_output_item( + session: &McpToolSession<'_>, + item: &ResponseOutputItem, + user_function_names: &std::collections::HashSet, +) -> bool { + match item { + ResponseOutputItem::McpListTools { server_label, .. } => { + !session.is_builtin_server_label(server_label) + && !session.is_internal_non_builtin_server_label(server_label) + } + ResponseOutputItem::McpCall { + server_label, name, .. + } + | ResponseOutputItem::McpApprovalRequest { + server_label, name, .. + } => !session.should_hide_mcp_call_like_by_label(name, server_label), + ResponseOutputItem::FunctionToolCall { name, .. } => { + !session.should_hide_function_call_like(name, user_function_names) + } + ResponseOutputItem::WebSearchCall { .. } + | ResponseOutputItem::CodeInterpreterCall { .. } + | ResponseOutputItem::FileSearchCall { .. } + | ResponseOutputItem::ImageGenerationCall { .. } + | ResponseOutputItem::ComputerCall { .. } + | ResponseOutputItem::ComputerCallOutput { .. } + | ResponseOutputItem::ShellCall { .. } + | ResponseOutputItem::ShellCallOutput { .. } + | ResponseOutputItem::ApplyPatchCall { .. } + | ResponseOutputItem::ApplyPatchCallOutput { .. } + | ResponseOutputItem::Message { .. } + | ResponseOutputItem::Reasoning { .. } + | ResponseOutputItem::Compaction { .. } + | ResponseOutputItem::LocalShellCall { .. } + | ResponseOutputItem::LocalShellCallOutput { .. } => true, + } +} diff --git a/crates/mcp/src/transform/transformer.rs b/model_gateway/src/routers/common/openai_bridge/transformer.rs similarity index 92% rename from crates/mcp/src/transform/transformer.rs rename to model_gateway/src/routers/common/openai_bridge/transformer.rs index dc44c1d0f..3ec3fe745 100644 --- a/crates/mcp/src/transform/transformer.rs +++ b/model_gateway/src/routers/common/openai_bridge/transformer.rs @@ -9,6 +9,29 @@ use tracing::warn; use super::ResponseFormat; +/// Transform a `ToolExecutionOutput` to a `ResponseOutputItem` using a +/// pre-resolved `ResponseFormat`. +/// +/// The format MUST be resolved via the session's exposed-name map (e.g. +/// [`super::lookup_tool_format`]). `output.tool_name` is the *invoked/exposed* +/// name after `McpToolSession::execute_tool_result` rewrites it, so a registry +/// lookup against `(output.server_key, output.tool_name)` would miss for +/// disambiguated names like `mcp__` and silently degrade to +/// `Passthrough`. +pub fn transform_tool_output( + output: &smg_mcp::ToolExecutionOutput, + response_format: ResponseFormat, +) -> ResponseOutputItem { + ResponseTransformer::transform( + &output.output, + response_format, + &output.call_id, + &output.server_label, + &output.tool_name, + &output.arguments_str, + ) +} + /// Normalize an MCP response item id source into an external `mcp_call.id`. /// /// The input may be an upstream output item id (`fc_*`), an internal call id @@ -81,7 +104,7 @@ impl ResponseTransformer { /// Returns a `ResponseOutputItem` from the protocols crate. pub fn transform( result: &serde_json::Value, - format: &ResponseFormat, + format: ResponseFormat, tool_call_id: &str, server_label: &str, tool_name: &str, @@ -600,15 +623,31 @@ impl ResponseTransformer { /// For any non-image item this is a no-op. pub fn compact_image_generation_output(item: &mut ResponseOutputItem) { if let ResponseOutputItem::ImageGenerationCall { result, .. } = item { - // `result.clear()` zeros the length but keeps the heap buffer - // allocated — for base64 image bytes that can be several MB of - // wasted capacity per stored item, defeating the compaction goal. - // Replace with a fresh empty string to actually free the backing - // buffer. + // Drop the heap buffer; `clear()` would retain (multi-MB) capacity. *result = String::new(); } } +/// Compact every `image_generation_call` item inside `outputs` in-place. JSON +/// counterpart to [`compact_image_generation_output`] used by the storage +/// layer, which works with `serde_json::Value` arrays rather than typed +/// `ResponseOutputItem`s. Non-image items are untouched. +pub fn compact_image_generation_outputs_json(outputs: &mut [serde_json::Value]) { + for item in outputs { + let Some(obj) = item.as_object_mut() else { + continue; + }; + if obj.get("type").and_then(|v| v.as_str()) != Some("image_generation_call") { + continue; + } + // Remove rather than blank: the input/replay shape models `result` as + // `Option` with `skip_serializing_if`, so the id-only multi-turn + // reference form requires field absence — an empty string would still + // surface on the wire and diverge from the documented replay payload. + obj.remove("result"); + } +} + fn parse_text_block_payload(item: &serde_json::Value) -> Option { let Some(obj) = item.as_object() else { warn!("Expected MCP result item to be an object"); @@ -644,7 +683,7 @@ mod tests { let result = json!({"key": "value"}); let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "call_test-1", "server", "tool", @@ -668,7 +707,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "test-2", "server", "tool", @@ -688,7 +727,7 @@ mod tests { let result = json!({"key": "value"}); let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "fc_abc123", "server", "tool", @@ -712,7 +751,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "test-3", "server", "tool", @@ -732,7 +771,7 @@ mod tests { let result = json!({"key": "value"}); let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "mcp_existing", "server", "tool", @@ -757,7 +796,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "test-4", "server", "tool", @@ -781,7 +820,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "test-4b", "server", "tool", @@ -807,7 +846,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "test-5", "server", "tool", @@ -836,7 +875,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::Passthrough, + ResponseFormat::Passthrough, "test-6", "server", "tool", @@ -870,7 +909,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::WebSearchCall, + ResponseFormat::WebSearchCall, "req-123", "server", "web_search", @@ -924,7 +963,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::WebSearchCall, + ResponseFormat::WebSearchCall, "req-embedded", "server", "web_search", @@ -961,7 +1000,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::WebSearchCall, + ResponseFormat::WebSearchCall, "req-legacy", "server", "web_search", @@ -997,7 +1036,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::CodeInterpreterCall, + ResponseFormat::CodeInterpreterCall, "req-456", "server", "code_interpreter", @@ -1034,7 +1073,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::FileSearchCall, + ResponseFormat::FileSearchCall, "req-789", "server", "file_search", @@ -1069,7 +1108,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-1", "server", "image_generation", @@ -1104,7 +1143,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-2", "server", "image_generation", @@ -1140,7 +1179,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-3", "server", "image_generation", @@ -1182,7 +1221,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-toplevel", "server", "image_generation", @@ -1216,7 +1255,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-alias", "server", "image_generation", @@ -1257,7 +1296,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-dist", "server", "image_generation", @@ -1286,7 +1325,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-4", "server", "image_generation", @@ -1329,7 +1368,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-meta", "server", "image_generation", @@ -1389,7 +1428,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-meta-textblock", "server", "image_generation", @@ -1427,7 +1466,7 @@ mod tests { let transformed = ResponseTransformer::transform( &result, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, "req-img-nometa", "server", "image_generation", @@ -1507,4 +1546,37 @@ mod tests { _ => panic!("Expected WebSearchCall"), } } + + #[test] + fn test_compact_image_generation_outputs_json_strips_base64() { + let mut outputs = vec![ + serde_json::json!({ + "type": "image_generation_call", + "id": "ig_1", + "result": "AAAA_LONG_BASE64", + "revised_prompt": "cat", + "status": "completed" + }), + // Untouched: not an image_generation_call. + serde_json::json!({ + "type": "mcp_call", + "id": "mcp_1", + "output": "raw text", + }), + // Image item with no `result` field: must not panic. + serde_json::json!({ + "type": "image_generation_call", + "id": "ig_2", + "status": "in_progress" + }), + ]; + + compact_image_generation_outputs_json(&mut outputs); + + assert!(outputs[0].get("result").is_none()); + assert_eq!(outputs[0]["revised_prompt"], serde_json::json!("cat")); + assert_eq!(outputs[0]["status"], serde_json::json!("completed")); + assert_eq!(outputs[1]["output"], serde_json::json!("raw text")); + assert!(outputs[2].get("result").is_none()); + } } diff --git a/model_gateway/src/routers/common/persistence_utils.rs b/model_gateway/src/routers/common/persistence_utils.rs index 40e4904cc..98010bd8e 100644 --- a/model_gateway/src/routers/common/persistence_utils.rs +++ b/model_gateway/src/routers/common/persistence_utils.rs @@ -15,6 +15,8 @@ use smg_data_connector::{ }; use tracing::{debug, info, warn}; +use super::openai_bridge; + // ============================================================================ // Constants // ============================================================================ @@ -171,9 +173,13 @@ fn get_string(json: &Value, key: &str) -> Option { json.get(key).and_then(|v| v.as_str()).map(String::from) } -/// Build a StoredResponse from response JSON and original request +/// Build a StoredResponse from response JSON and original request. +/// +/// Takes `response_json` by value so it can be moved into `raw_response` +/// without an extra clone — the persistence path passes a freshly-compacted +/// JSON object that has no other live references. pub fn build_stored_response( - response_json: &Value, + response_json: Value, original_body: &ResponsesRequest, ) -> StoredResponse { let mut stored = StoredResponse::new(None); @@ -181,7 +187,8 @@ pub fn build_stored_response( // Initialize empty array - will be populated by persist_conversation_items stored.input = Value::Array(vec![]); - stored.model = get_string(response_json, "model").or_else(|| Some(original_body.model.clone())); + stored.model = + get_string(&response_json, "model").or_else(|| Some(original_body.model.clone())); stored.safety_identifier.clone_from(&original_body.user); // `StoredResponse.conversation_id` is `Option`; flatten the @@ -191,7 +198,7 @@ pub fn build_stored_response( .as_ref() .map(|c| c.as_id().to_string()); - stored.previous_response_id = get_string(response_json, "previous_response_id") + stored.previous_response_id = get_string(&response_json, "previous_response_id") .map(|s| ResponseId::from(s.as_str())) .or_else(|| { original_body @@ -200,11 +207,11 @@ pub fn build_stored_response( .map(ResponseId::from) }); - if let Some(id_str) = get_string(response_json, "id") { + if let Some(id_str) = get_string(&response_json, "id") { stored.id = ResponseId::from(id_str.as_str()); } - stored.raw_response = response_json.clone(); + stored.raw_response = response_json; stored } @@ -288,6 +295,13 @@ fn extract_input_items(input: &ResponseInput) -> Result, String> { } } + // Strip image_generation_call.result base64 from + // historical replayed items before persistence — + // a no-op for non-image item types. + openai_bridge::compact_image_generation_outputs_json( + std::slice::from_mut(&mut value), + ); + Ok(value) } } @@ -415,24 +429,40 @@ async fn persist_conversation_items_inner( return Ok(()); } - // Extract response ID + // Clone once, then strip multi-MB base64 from any `image_generation_call` + // items before they reach storage. Replay only references images by `id`, + // so storing the bytes per turn would balloon the response chain. + let mut response_json = response_json.clone(); + if let Some(outputs) = response_json + .get_mut("output") + .and_then(|v| v.as_array_mut()) + { + openai_bridge::compact_image_generation_outputs_json(outputs); + } + + // Extract response ID and clone out the (already-compacted) output array + // before handing the rest of `response_json` to `build_stored_response`, + // which moves the value into `raw_response` without re-cloning the full + // response tree. let response_id_str = response_json .get("id") .and_then(|v| v.as_str()) - .ok_or_else(|| "Response missing id field".to_string())?; - let response_id = ResponseId::from(response_id_str); + .ok_or_else(|| "Response missing id field".to_string())? + .to_string(); + let response_id = ResponseId::from(response_id_str.as_str()); // Parse and normalize input items from request let input_items = extract_input_items(&original_body.input)?; - // Parse output items from response + // Cloning the array here is shallow against compacted items — base64 + // image payloads have already been stripped above. let output_items = response_json .get("output") .and_then(|v| v.as_array()) .cloned() .ok_or_else(|| "No output array in response".to_string())?; - // Build and store response + // Build and store response (consumes `response_json`). let mut stored_response = build_stored_response(response_json, original_body); stored_response.id = response_id.clone(); stored_response.input = Value::Array(input_items.clone()); @@ -464,7 +494,7 @@ async fn persist_conversation_items_inner( &conv_id, &input_items, &output_items, - response_id_str, + &response_id_str, ) .await?; info!( diff --git a/model_gateway/src/routers/conversations/handlers.rs b/model_gateway/src/routers/conversations/handlers.rs index 363a6efec..004a64eaa 100644 --- a/model_gateway/src/routers/conversations/handlers.rs +++ b/model_gateway/src/routers/conversations/handlers.rs @@ -15,7 +15,10 @@ use smg_data_connector::{ }; use tracing::info; -use crate::{memory::MemoryExecutionContext, routers::common::persistence_utils::item_to_json}; +use crate::{ + memory::MemoryExecutionContext, + routers::common::{openai_bridge, persistence_utils::item_to_json}, +}; // ============================================================================ // Constants @@ -622,7 +625,12 @@ fn parse_item_from_value( let content = if item_type == "message" || item_type == "reasoning" { item_val.get("content").cloned().unwrap_or(json!([])) } else { - item_val.clone() + // Strip image_generation_call.result base64 before storage. The + // compactor is a no-op for non-image item types, so it's safe to + // apply unconditionally on the non-message branch. + let mut content = item_val.clone(); + openai_bridge::compact_image_generation_outputs_json(std::slice::from_mut(&mut content)); + content }; Ok(( diff --git a/model_gateway/src/routers/gemini/context.rs b/model_gateway/src/routers/gemini/context.rs index fea47c7c9..8e5447dec 100644 --- a/model_gateway/src/routers/gemini/context.rs +++ b/model_gateway/src/routers/gemini/context.rs @@ -14,6 +14,7 @@ use smg_mcp::McpOrchestrator; use super::state::RequestState; use crate::{ middleware::TenantRequestMeta, + routers::common::openai_bridge::FormatRegistry, worker::{Worker, WorkerRegistry}, }; @@ -38,6 +39,9 @@ pub(crate) struct SharedComponents { #[expect(dead_code, reason = "MCP tool integration is Phase 2")] pub mcp_orchestrator: Arc, + #[expect(dead_code, reason = "MCP tool integration is Phase 2")] + pub mcp_format_registry: FormatRegistry, + /// Per-request timeout from router config. pub request_timeout: Duration, } diff --git a/model_gateway/src/routers/gemini/router.rs b/model_gateway/src/routers/gemini/router.rs index 56b50f448..accb85056 100644 --- a/model_gateway/src/routers/gemini/router.rs +++ b/model_gateway/src/routers/gemini/router.rs @@ -46,6 +46,7 @@ impl GeminiRouter { client: ctx.client.clone(), worker_registry: ctx.worker_registry.clone(), mcp_orchestrator, + mcp_format_registry: ctx.mcp_format_registry.clone(), request_timeout, }); let retry_config = ctx.router_config.effective_retry_config(); diff --git a/model_gateway/src/routers/grpc/common/responses/context.rs b/model_gateway/src/routers/grpc/common/responses/context.rs index d5835a1dc..df94023ae 100644 --- a/model_gateway/src/routers/grpc/common/responses/context.rs +++ b/model_gateway/src/routers/grpc/common/responses/context.rs @@ -10,7 +10,10 @@ use smg_data_connector::{ }; use smg_mcp::McpOrchestrator; -use crate::routers::grpc::{context::SharedComponents, pipeline::RequestPipeline}; +use crate::routers::{ + common::openai_bridge::FormatRegistry, + grpc::{context::SharedComponents, pipeline::RequestPipeline}, +}; /// Context for /v1/responses endpoint /// @@ -39,6 +42,8 @@ pub(crate) struct ResponsesContext { /// MCP orchestrator for tool support pub mcp_orchestrator: Arc, + pub mcp_format_registry: FormatRegistry, + /// Storage hook request context extracted from HTTP headers by middleware. pub request_context: Option, } @@ -57,6 +62,7 @@ impl ResponsesContext { conversation_item_storage: Arc, conversation_memory_writer: Arc, mcp_orchestrator: Arc, + mcp_format_registry: FormatRegistry, request_context: Option, ) -> Self { Self { @@ -67,6 +73,7 @@ impl ResponsesContext { conversation_item_storage, conversation_memory_writer, mcp_orchestrator, + mcp_format_registry, request_context, } } diff --git a/model_gateway/src/routers/grpc/common/responses/streaming.rs b/model_gateway/src/routers/grpc/common/responses/streaming.rs index 5cc119df8..00793e1d5 100644 --- a/model_gateway/src/routers/grpc/common/responses/streaming.rs +++ b/model_gateway/src/routers/grpc/common/responses/streaming.rs @@ -7,22 +7,24 @@ use openai_protocol::{ chat::ChatCompletionStreamResponse, common::{Usage, UsageInfo}, event_types::{ - CodeInterpreterCallEvent, ContentPartEvent, FileSearchCallEvent, FunctionCallEvent, - ImageGenerationCallEvent, McpEvent, OutputItemEvent, OutputTextEvent, ResponseEvent, - WebSearchCallEvent, + ContentPartEvent, FunctionCallEvent, McpEvent, OutputItemEvent, OutputTextEvent, + ResponseEvent, }, responses::{ ResponseOutputItem, ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage, }, }; use serde_json::json; -use smg_mcp::{self as mcp, ResponseFormat}; +use smg_mcp::{self as mcp}; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::warn; use uuid::Uuid; -use crate::routers::grpc::harmony::responses::ToolResult; +use crate::routers::{ + common::openai_bridge::{self, descriptor, ResponseFormat}, + grpc::harmony::responses::ToolResult, +}; pub(crate) enum OutputItemType { Message, @@ -354,11 +356,11 @@ impl ResponseStreamEventEmitter { }) } - /// Convert tool entries to JSON values using the shared `build_mcp_tool_infos` bridge. + /// Convert tool entries to JSON values using the shared bridge builder. fn tool_entries_to_json( tools: &[mcp::ToolEntry], ) -> Result, serde_json::Error> { - mcp::build_mcp_tool_infos(tools) + openai_bridge::build_mcp_tool_infos(tools) .into_iter() .map(serde_json::to_value) .collect() @@ -465,54 +467,30 @@ impl ResponseStreamEventEmitter { }) } - /// Emit the appropriate in_progress event based on response format pub fn emit_tool_call_in_progress( &mut self, output_index: usize, item_id: &str, - response_format: &ResponseFormat, + response_format: ResponseFormat, ) -> serde_json::Value { - let event_type = match response_format { - ResponseFormat::WebSearchCall => WebSearchCallEvent::IN_PROGRESS, - ResponseFormat::CodeInterpreterCall => CodeInterpreterCallEvent::IN_PROGRESS, - ResponseFormat::FileSearchCall => FileSearchCallEvent::IN_PROGRESS, - ResponseFormat::ImageGenerationCall => ImageGenerationCallEvent::IN_PROGRESS, - ResponseFormat::Passthrough => McpEvent::CALL_IN_PROGRESS, - }; + let event_type = descriptor(response_format).in_progress_event; self.emit_tool_event(event_type, output_index, item_id) } - /// Emit the searching/interpreting/generating event for builtin tool calls (no-op for passthrough). - /// - /// For `image_generation_call` this emits the `generating` event. The - /// partial-image event is emitted separately via `emit_image_generation_partial_image` - /// because it carries additional payload (the partial b64 bytes) and is - /// optional per the `partial_images` request field. + /// Emit the searching/interpreting/generating event; `None` for formats + /// with no intermediate phase. pub fn emit_tool_call_searching( &mut self, output_index: usize, item_id: &str, - response_format: &ResponseFormat, + response_format: ResponseFormat, ) -> Option { - let event_type = match response_format { - ResponseFormat::WebSearchCall => WebSearchCallEvent::SEARCHING, - ResponseFormat::CodeInterpreterCall => CodeInterpreterCallEvent::INTERPRETING, - ResponseFormat::FileSearchCall => FileSearchCallEvent::SEARCHING, - ResponseFormat::ImageGenerationCall => ImageGenerationCallEvent::GENERATING, - ResponseFormat::Passthrough => return None, - }; + let event_type = descriptor(response_format).searching_event?; Some(self.emit_tool_event(event_type, output_index, item_id)) } - /// Emit a `response.image_generation_call.partial_image` event. - /// - /// Returns `None` when `response_format` is anything other than - /// [`ResponseFormat::ImageGenerationCall`], mirroring how - /// `emit_tool_call_searching` gates on format. The payload carries the - /// base64-encoded partial image bytes plus a 0-based partial image index. - /// - /// Per-router wiring is responsible for deciding when to call this and - /// how to source the partial-image bytes. + /// Emit a `response.image_generation_call.partial_image` event. Returns + /// `None` for formats with no partial-image frame. #[expect( dead_code, reason = "partial_image emission is wired by per-router integrations" @@ -521,15 +499,13 @@ impl ResponseStreamEventEmitter { &mut self, output_index: usize, item_id: &str, - response_format: &ResponseFormat, + response_format: ResponseFormat, partial_image_index: u32, partial_image_b64: &str, ) -> Option { - if !matches!(response_format, ResponseFormat::ImageGenerationCall) { - return None; - } + let event_type = descriptor(response_format).partial_image_event?; Some(json!({ - "type": ImageGenerationCallEvent::PARTIAL_IMAGE, + "type": event_type, "sequence_number": self.next_sequence(), "output_index": output_index, "item_id": item_id, @@ -538,40 +514,24 @@ impl ResponseStreamEventEmitter { })) } - /// Emit the appropriate completed event based on response format pub fn emit_tool_call_completed( &mut self, output_index: usize, item_id: &str, - response_format: &ResponseFormat, + response_format: ResponseFormat, ) -> serde_json::Value { - let event_type = match response_format { - ResponseFormat::WebSearchCall => WebSearchCallEvent::COMPLETED, - ResponseFormat::CodeInterpreterCall => CodeInterpreterCallEvent::COMPLETED, - ResponseFormat::FileSearchCall => FileSearchCallEvent::COMPLETED, - ResponseFormat::ImageGenerationCall => ImageGenerationCallEvent::COMPLETED, - ResponseFormat::Passthrough => McpEvent::CALL_COMPLETED, - }; + let event_type = descriptor(response_format).completed_event; self.emit_tool_event(event_type, output_index, item_id) } - // ======================================================================== - // Helper Methods for ResponseFormat - // ======================================================================== - - /// Get the type string for JSON based on response format. pub fn type_str_for_format(response_format: Option<&ResponseFormat>) -> &'static str { match response_format { - Some(ResponseFormat::WebSearchCall) => "web_search_call", - Some(ResponseFormat::CodeInterpreterCall) => "code_interpreter_call", - Some(ResponseFormat::FileSearchCall) => "file_search_call", - Some(ResponseFormat::ImageGenerationCall) => "image_generation_call", - Some(ResponseFormat::Passthrough) => "mcp_call", + Some(format) => descriptor(*format).type_str, None => "function_call", } } - /// Get the OutputItemType based on response format. + /// Map a `ResponseFormat` to the grpc-router-private `OutputItemType` enum. pub fn output_item_type_for_format(response_format: Option<&ResponseFormat>) -> OutputItemType { match response_format { Some(ResponseFormat::WebSearchCall) => OutputItemType::WebSearchCall, diff --git a/model_gateway/src/routers/grpc/common/responses/utils.rs b/model_gateway/src/routers/grpc/common/responses/utils.rs index 66b09b63d..336865050 100644 --- a/model_gateway/src/routers/grpc/common/responses/utils.rs +++ b/model_gateway/src/routers/grpc/common/responses/utils.rs @@ -18,7 +18,8 @@ use tracing::{debug, error, warn}; use crate::{ routers::{ common::{ - mcp_utils::ensure_request_mcp_client, persistence_utils::persist_conversation_items, + mcp_utils::ensure_request_mcp_client, openai_bridge, + persistence_utils::persist_conversation_items, }, error, }, @@ -34,6 +35,7 @@ use crate::{ /// Returns Ok((has_mcp_tools, mcp_servers)) on success. pub(crate) async fn ensure_mcp_connection( mcp_orchestrator: &Arc, + format_registry: &openai_bridge::FormatRegistry, tools: Option<&[ResponseTool]>, ) -> Result<(bool, Vec), Response> { // Check for explicit MCP tools (must error if connection fails) @@ -72,7 +74,7 @@ pub(crate) async fn ensure_mcp_connection( if let Some(tools) = tools { // TODO: Thread real request headers through the gRPC responses path if/when // gRPC MCP flows need the same forwarded-header preservation contract. - match ensure_request_mcp_client(mcp_orchestrator, tools).await { + match ensure_request_mcp_client(mcp_orchestrator, format_registry, tools).await { Some(mcp_servers) => { return Ok((true, mcp_servers)); } diff --git a/model_gateway/src/routers/grpc/harmony/responses/common.rs b/model_gateway/src/routers/grpc/harmony/responses/common.rs index 24770772c..16ac8e854 100644 --- a/model_gateway/src/routers/grpc/harmony/responses/common.rs +++ b/model_gateway/src/routers/grpc/harmony/responses/common.rs @@ -13,12 +13,16 @@ use openai_protocol::{ }; use serde_json::{from_value, to_string, Value}; use smg_data_connector::{ResponseId, ResponseStorageError}; -use smg_mcp::{McpToolSession, ResponseFormat}; +use smg_mcp::McpToolSession; use tracing::{debug, error, warn}; use uuid::Uuid; use super::execution::ToolResult; -use crate::routers::{error, grpc::common::responses::ResponsesContext}; +use crate::routers::{ + common::openai_bridge::{self, FormatRegistry, ResponseFormat}, + error, + grpc::common::responses::ResponsesContext, +}; /// Record of a single MCP tool call execution /// @@ -177,7 +181,8 @@ pub(super) fn inject_mcp_metadata( .map(|record| record.output_item.clone()) .collect(); - session.inject_client_visible_mcp_output_items( + openai_bridge::inject_client_visible_mcp_output_items( + session, &mut response.output, tool_output_items, user_function_names, @@ -325,15 +330,22 @@ pub(super) async fn load_previous_messages( pub(super) fn strip_image_generation_from_request_tools( request: &mut ResponsesRequest, session: &McpToolSession<'_>, + format_registry: &FormatRegistry, ) { - // Check whether any MCP tool in the session carries the - // `ImageGenerationCall` response format — this is the authoritative - // signal that an MCP server is routed for the hosted - // `image_generation` tool in this request. + // Strip only when the session exposes the literal `image_generation` + // name routed to an `ImageGenerationCall` format. Checking every MCP + // tool's format would also fire for unrelated custom tools that happen + // to share the image-generation output shape (e.g. a `thumbnailer`), + // and would then drop the real hosted `image_generation` tag even + // though the session has no dispatcher for it. let mcp_has_image_generation = session - .mcp_tools() - .iter() - .any(|entry| matches!(entry.response_format, ResponseFormat::ImageGenerationCall)); + .qualified_name_for_exposed("image_generation") + .is_some_and(|qn| { + matches!( + format_registry.lookup(&qn), + ResponseFormat::ImageGenerationCall + ) + }); if !mcp_has_image_generation { return; diff --git a/model_gateway/src/routers/grpc/harmony/responses/execution.rs b/model_gateway/src/routers/grpc/harmony/responses/execution.rs index a59cdf461..d4a6b0853 100644 --- a/model_gateway/src/routers/grpc/harmony/responses/execution.rs +++ b/model_gateway/src/routers/grpc/harmony/responses/execution.rs @@ -12,7 +12,10 @@ use tracing::{debug, error}; use super::common::McpCallTracking; use crate::{ observability::metrics::{metrics_labels, Metrics}, - routers::common::mcp_utils::prepare_hosted_dispatch_args, + routers::common::{ + mcp_utils::prepare_hosted_dispatch_args, + openai_bridge::{self, FormatRegistry}, + }, }; /// Tool execution result @@ -53,6 +56,7 @@ pub(crate) struct ToolResult { /// Vector of tool results (one per tool call) pub(super) async fn execute_mcp_tools( session: &McpToolSession<'_>, + format_registry: &FormatRegistry, tool_calls: &[ToolCall], tracking: &mut McpCallTracking, model_id: &str, @@ -90,8 +94,9 @@ pub(super) async fn execute_mcp_tools( json!({}) } }; - let response_format = session.tool_response_format(&tc.function.name); - prepare_hosted_dispatch_args(&mut args, &response_format, request_tools, request_user); + let response_format = + openai_bridge::lookup_tool_format(session, format_registry, &tc.function.name); + prepare_hosted_dispatch_args(&mut args, response_format, request_tools, request_user); ToolExecutionInput { call_id: tc.id.clone(), tool_name: tc.function.name.clone(), @@ -112,8 +117,9 @@ pub(super) async fn execute_mcp_tools( let results: Vec = outputs .into_iter() .map(|output| { - // Transform to correctly-typed ResponseOutputItem - let output_item = output.to_response_item(); + let response_format = + openai_bridge::lookup_tool_format(session, format_registry, &output.tool_name); + let output_item = openai_bridge::transform_tool_output(&output, response_format); // Record this call in tracking tracking.record_call(output_item.clone()); @@ -145,5 +151,5 @@ pub(super) async fn execute_mcp_tools( pub(crate) fn convert_mcp_tools_to_response_tools( session: &McpToolSession<'_>, ) -> Vec { - session.build_response_tools() + openai_bridge::response_tools(session) } diff --git a/model_gateway/src/routers/grpc/harmony/responses/non_streaming.rs b/model_gateway/src/routers/grpc/harmony/responses/non_streaming.rs index 3fd641141..180c8e600 100644 --- a/model_gateway/src/routers/grpc/harmony/responses/non_streaming.rs +++ b/model_gateway/src/routers/grpc/harmony/responses/non_streaming.rs @@ -63,8 +63,12 @@ pub(crate) async fn serve_harmony_responses( let current_request = load_previous_messages(ctx, request).await?; // Check MCP connection and get whether MCP tools are present - let (has_mcp_tools, mcp_servers) = - ensure_mcp_connection(&ctx.mcp_orchestrator, current_request.tools.as_deref()).await?; + let (has_mcp_tools, mcp_servers) = ensure_mcp_connection( + &ctx.mcp_orchestrator, + &ctx.mcp_format_registry, + current_request.tools.as_deref(), + ) + .await?; let response = if has_mcp_tools { execute_with_mcp_loop( @@ -140,7 +144,11 @@ async fn execute_with_mcp_loop( // advertises only the MCP-exposed function-tool name (which // `has_exposed_tool` actually recognizes for dispatch). See the // helper doc comment in `common.rs` for the full rationale. - strip_image_generation_from_request_tools(&mut current_request, &session); + strip_image_generation_from_request_tools( + &mut current_request, + &session, + &ctx.mcp_format_registry, + ); loop { iteration_count += 1; @@ -267,6 +275,7 @@ async fn execute_with_mcp_loop( } else { execute_mcp_tools( &session, + &ctx.mcp_format_registry, &mcp_tool_calls, &mut mcp_tracking, ¤t_request.model, diff --git a/model_gateway/src/routers/grpc/harmony/responses/streaming.rs b/model_gateway/src/routers/grpc/harmony/responses/streaming.rs index 496a348af..d81205444 100644 --- a/model_gateway/src/routers/grpc/harmony/responses/streaming.rs +++ b/model_gateway/src/routers/grpc/harmony/responses/streaming.rs @@ -55,6 +55,7 @@ pub(crate) async fn serve_harmony_responses_stream( // Check MCP connection BEFORE starting stream and get whether MCP tools are present let (has_mcp_tools, mcp_servers) = match ensure_mcp_connection( &ctx.mcp_orchestrator, + &ctx.mcp_format_registry, current_request.tools.as_deref(), ) .await @@ -170,7 +171,11 @@ async fn execute_mcp_tool_loop_streaming( // advertises only the MCP-exposed function-tool name (which // `has_exposed_tool` actually recognizes for dispatch). See the // helper doc comment in `common.rs` for the full rationale. - strip_image_generation_from_request_tools(&mut current_request, &session); + strip_image_generation_from_request_tools( + &mut current_request, + &session, + &ctx.mcp_format_registry, + ); let mut mcp_tracking = McpCallTracking::new(); @@ -241,6 +246,7 @@ async fn execute_mcp_tool_loop_streaming( emitter, tx, Some(&session), + Some(&ctx.mcp_format_registry), ) .await { @@ -319,6 +325,7 @@ async fn execute_mcp_tool_loop_streaming( } else { match execute_mcp_tools( &session, + &ctx.mcp_format_registry, &mcp_tool_calls, &mut mcp_tracking, ¤t_request.model, @@ -455,6 +462,7 @@ async fn execute_without_mcp_streaming( emitter, tx, None, + None, ) .await { diff --git a/model_gateway/src/routers/grpc/harmony/streaming.rs b/model_gateway/src/routers/grpc/harmony/streaming.rs index 0515a7185..fef746643 100644 --- a/model_gateway/src/routers/grpc/harmony/streaming.rs +++ b/model_gateway/src/routers/grpc/harmony/streaming.rs @@ -20,7 +20,7 @@ use openai_protocol::{ }, }; use serde_json::json; -use smg_mcp::{McpToolSession, ResponseFormat, DEFAULT_SERVER_LABEL}; +use smg_mcp::{McpToolSession, DEFAULT_SERVER_LABEL}; use tokio::sync::mpsc; use tracing::{debug, error}; @@ -30,17 +30,22 @@ use super::{ }; use crate::{ observability::metrics::{metrics_labels, Metrics, StreamingMetricsParams}, - routers::grpc::{ - common::{ - response_formatting::CompletionTokenTracker, - responses::{ - build_sse_response, - streaming::{attach_mcp_server_label, OutputItemType, ResponseStreamEventEmitter}, + routers::{ + common::openai_bridge::{self, descriptor, FormatRegistry, ResponseFormat}, + grpc::{ + common::{ + response_formatting::CompletionTokenTracker, + responses::{ + build_sse_response, + streaming::{ + attach_mcp_server_label, OutputItemType, ResponseStreamEventEmitter, + }, + }, }, + context, + proto_wrapper::{ProtoResponseVariant, ProtoStream}, + utils, }, - context, - proto_wrapper::{ProtoResponseVariant, ProtoStream}, - utils, }, }; @@ -59,13 +64,9 @@ use crate::{ /// `None` (plain function tools) and `Some(Passthrough)` (MCP `mcp_call`) /// are the only formats that stream arguments through this router. fn streams_arguments(response_format: Option<&ResponseFormat>) -> bool { - match response_format { - None | Some(ResponseFormat::Passthrough) => true, - Some(ResponseFormat::WebSearchCall) - | Some(ResponseFormat::CodeInterpreterCall) - | Some(ResponseFormat::FileSearchCall) - | Some(ResponseFormat::ImageGenerationCall) => false, - } + response_format + .map(|f| descriptor(*f).streams_arguments) + .unwrap_or(true) } /// Processor for streaming Harmony responses @@ -499,15 +500,24 @@ impl HarmonyStreamingProcessor { emitter: &mut ResponseStreamEventEmitter, tx: &mpsc::UnboundedSender>, session: Option<&McpToolSession<'_>>, + format_registry: Option<&FormatRegistry>, ) -> Result { match execution_result { context::ExecutionResult::Single { stream } => { debug!("Processing Responses API single stream mode"); - Self::process_decode_stream(stream, emitter, tx, session, 0).await + Self::process_decode_stream(stream, emitter, tx, session, format_registry, 0).await } context::ExecutionResult::Dual { prefill, decode } => { debug!("Processing Responses API dual stream mode"); - Self::process_responses_dual_stream(prefill, *decode, emitter, tx, session).await + Self::process_responses_dual_stream( + prefill, + *decode, + emitter, + tx, + session, + format_registry, + ) + .await } context::ExecutionResult::Embedding { .. } => { Err("Embeddings not supported in Responses API streaming".to_string()) @@ -521,6 +531,7 @@ impl HarmonyStreamingProcessor { emitter: &mut ResponseStreamEventEmitter, tx: &mpsc::UnboundedSender>, session: Option<&McpToolSession<'_>>, + format_registry: Option<&FormatRegistry>, ) -> Result { // Phase 1: Drain prefill stream, collecting cached_tokens from Complete messages let mut prefill_cached_tokens_by_index: HashMap = HashMap::new(); @@ -534,9 +545,15 @@ impl HarmonyStreamingProcessor { let prefill_cached_tokens: u32 = prefill_cached_tokens_by_index.values().sum(); // Phase 2: Process decode stream - let result = - Self::process_decode_stream(decode_stream, emitter, tx, session, prefill_cached_tokens) - .await; + let result = Self::process_decode_stream( + decode_stream, + emitter, + tx, + session, + format_registry, + prefill_cached_tokens, + ) + .await; prefill_stream.mark_completed(); result @@ -548,6 +565,7 @@ impl HarmonyStreamingProcessor { emitter: &mut ResponseStreamEventEmitter, tx: &mpsc::UnboundedSender>, session: Option<&McpToolSession<'_>>, + format_registry: Option<&FormatRegistry>, prefill_cached_tokens: u32, ) -> Result { let mut parser = @@ -680,7 +698,9 @@ impl HarmonyStreamingProcessor { // Determine response_format based on MCP context. let response_format = session.and_then(|s| { if s.has_exposed_tool(tool_name) { - Some(s.tool_response_format(tool_name)) + format_registry.map(|reg| { + openai_bridge::lookup_tool_format(s, reg, tool_name) + }) } else { None } @@ -700,7 +720,7 @@ impl HarmonyStreamingProcessor { tool_call_tracking.insert( call_index, - (output_index, item_id.clone(), response_format.clone()), + (output_index, item_id.clone(), response_format), ); // Build output_item.added event @@ -726,7 +746,7 @@ impl HarmonyStreamingProcessor { emitter.send_event_best_effort(&event, tx); // Emit in_progress event for MCP tools - if let Some(ref fmt) = response_format { + if let Some(fmt) = response_format { let event = emitter.emit_tool_call_in_progress( output_index, &item_id, @@ -861,7 +881,7 @@ impl HarmonyStreamingProcessor { } // Emit completed event for MCP tools - if let Some(ref fmt) = response_format { + if let Some(fmt) = *response_format { let event = emitter.emit_tool_call_completed( *output_index, item_id, @@ -998,7 +1018,7 @@ impl HarmonyStreamingProcessor { } // Emit completed event for MCP tools - if let Some(ref fmt) = response_format { + if let Some(fmt) = *response_format { let event = emitter.emit_tool_call_completed(*output_index, item_id, fmt); emitter.send_event_best_effort(&event, tx); @@ -1113,7 +1133,7 @@ mod tests { /// the production classifier so a drift between the two is a separate /// failure (runtime assertion miss) from a missing variant (compile /// error here). - fn expected_streams_arguments(format: &ResponseFormat) -> bool { + fn expected_streams_arguments(format: ResponseFormat) -> bool { match format { ResponseFormat::Passthrough => true, ResponseFormat::WebSearchCall @@ -1144,28 +1164,28 @@ mod tests { streams_arguments(Some(&passthrough)), "mcp_call (Passthrough) should stream args", ); - assert!(expected_streams_arguments(&passthrough)); + assert!(expected_streams_arguments(passthrough)); // Hosted built-ins do *not* stream arguments — they surface // progress via structured `*.in_progress` / `*.searching` / // `*.generating` / `*.completed` events from the shared emitter. let web_search = ResponseFormat::WebSearchCall; assert!(!streams_arguments(Some(&web_search))); - assert!(!expected_streams_arguments(&web_search)); + assert!(!expected_streams_arguments(web_search)); let code_interpreter = ResponseFormat::CodeInterpreterCall; assert!(!streams_arguments(Some(&code_interpreter))); - assert!(!expected_streams_arguments(&code_interpreter)); + assert!(!expected_streams_arguments(code_interpreter)); let file_search = ResponseFormat::FileSearchCall; assert!(!streams_arguments(Some(&file_search))); - assert!(!expected_streams_arguments(&file_search)); + assert!(!expected_streams_arguments(file_search)); let image_generation = ResponseFormat::ImageGenerationCall; assert!( !streams_arguments(Some(&image_generation)), "image_generation_call must ride the structured-event path", ); - assert!(!expected_streams_arguments(&image_generation)); + assert!(!expected_streams_arguments(image_generation)); } } diff --git a/model_gateway/src/routers/grpc/regular/responses/common.rs b/model_gateway/src/routers/grpc/regular/responses/common.rs index a152ba934..11845714b 100644 --- a/model_gateway/src/routers/grpc/regular/responses/common.rs +++ b/model_gateway/src/routers/grpc/regular/responses/common.rs @@ -24,7 +24,8 @@ use tracing::{debug, warn}; use crate::{ middleware::TenantRequestMeta, routers::{ - common::persistence_utils::split_stored_message_content, error, + common::{openai_bridge, persistence_utils::split_stored_message_content}, + error, grpc::common::responses::ResponsesContext, }, }; @@ -154,7 +155,7 @@ pub(super) fn extract_all_tool_calls_from_chat( } pub(super) fn convert_mcp_tools_to_chat_tools(session: &McpToolSession<'_>) -> Vec { - session.build_chat_function_tools() + openai_bridge::chat_function_tools(session) } // ============================================================================ diff --git a/model_gateway/src/routers/grpc/regular/responses/handlers.rs b/model_gateway/src/routers/grpc/regular/responses/handlers.rs index bcad8eeb7..7de0ed3a3 100644 --- a/model_gateway/src/routers/grpc/regular/responses/handlers.rs +++ b/model_gateway/src/routers/grpc/regular/responses/handlers.rs @@ -118,11 +118,16 @@ async fn route_responses_streaming( }; // 2. Check MCP connection and get whether MCP tools are present - let (has_mcp_tools, mcp_servers) = - match ensure_mcp_connection(&ctx.mcp_orchestrator, request.tools.as_deref()).await { - Ok(result) => result, - Err(response) => return response, - }; + let (has_mcp_tools, mcp_servers) = match ensure_mcp_connection( + &ctx.mcp_orchestrator, + &ctx.mcp_format_registry, + request.tools.as_deref(), + ) + .await + { + Ok(result) => result, + Err(response) => return response, + }; if has_mcp_tools { debug!("MCP tools detected in streaming mode, using streaming tool loop"); diff --git a/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs b/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs index 483092800..1918307a5 100644 --- a/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs +++ b/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs @@ -24,7 +24,10 @@ use super::{ use crate::{ observability::metrics::{metrics_labels, Metrics}, routers::{ - common::mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS}, + common::{ + mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS}, + openai_bridge, + }, error, grpc::common::responses::{ collect_user_function_names, ensure_mcp_connection, persist_response_if_needed, @@ -49,8 +52,12 @@ pub(super) async fn route_responses_internal( let modified_request = load_conversation_history(ctx, &request).await?; // 2. Check MCP connection and get whether MCP tools are present - let (has_mcp_tools, mcp_servers) = - ensure_mcp_connection(&ctx.mcp_orchestrator, request.tools.as_deref()).await?; + let (has_mcp_tools, mcp_servers) = ensure_mcp_connection( + &ctx.mcp_orchestrator, + &ctx.mcp_format_registry, + request.tools.as_deref(), + ) + .await?; let responses_response = if has_mcp_tools { debug!("MCP tools detected, using tool loop"); @@ -228,7 +235,8 @@ pub(super) async fn execute_tool_loop( // Inject MCP metadata into output if state.total_calls > 0 { - session.inject_client_visible_mcp_output_items( + openai_bridge::inject_client_visible_mcp_output_items( + &session, &mut responses_response.output, state.mcp_call_items, &user_function_names, @@ -350,10 +358,14 @@ pub(super) async fn execute_tool_loop( Ok(serde_json::Value::Object(map)) => serde_json::Value::Object(map), _ => json!({}), }; - let response_format = session.tool_response_format(&tc.name); + let response_format = openai_bridge::lookup_tool_format( + &session, + &ctx.mcp_format_registry, + &tc.name, + ); prepare_hosted_dispatch_args( &mut arguments, - &response_format, + response_format, request_tools, request_user, ); @@ -395,7 +407,12 @@ pub(super) async fn execute_tool_loop( ); // Record the call in state with transformed output item - let output_item = result.to_response_item(); + let response_format = openai_bridge::lookup_tool_format( + &session, + &ctx.mcp_format_registry, + &result.tool_name, + ); + let output_item = openai_bridge::transform_tool_output(&result, response_format); let output_str = result.output.to_string(); state.record_call( result.call_id, diff --git a/model_gateway/src/routers/grpc/regular/responses/streaming.rs b/model_gateway/src/routers/grpc/regular/responses/streaming.rs index d269808d5..ad00ae4d6 100644 --- a/model_gateway/src/routers/grpc/regular/responses/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/responses/streaming.rs @@ -34,7 +34,7 @@ use smg_data_connector::{ ConversationItemStorage, ConversationStorage, RequestContext as StorageRequestContext, ResponseStorage, }; -use smg_mcp::{McpServerBinding, McpToolSession, ResponseFormat, ToolExecutionInput}; +use smg_mcp::{McpServerBinding, McpToolSession, ToolExecutionInput}; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, trace, warn}; @@ -50,7 +50,10 @@ use super::{ use crate::{ observability::metrics::{metrics_labels, Metrics}, routers::{ - common::mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS}, + common::{ + mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS}, + openai_bridge::{self, ResponseFormat}, + }, grpc::{ common::responses::{ build_sse_response, persist_response_if_needed, @@ -637,8 +640,11 @@ async fn execute_tool_loop_streaming_internal( tool_call.call_id ); - // Look up response_format for this tool - let response_format = session.tool_response_format(&tool_call.name); + let response_format = openai_bridge::lookup_tool_format( + &session, + &ctx.mcp_format_registry, + &tool_call.name, + ); // Use emitter helpers to determine correct type and allocate index let item_type = @@ -670,7 +676,7 @@ async fn execute_tool_loop_streaming_internal( // Emit tool_call.in_progress let event = - emitter.emit_tool_call_in_progress(output_index, &item_id, &response_format); + emitter.emit_tool_call_in_progress(output_index, &item_id, response_format); emitter.send_event(&event, &tx)?; // Emit arguments events for mcp_call only (skip for builtin tools) @@ -694,7 +700,7 @@ async fn execute_tool_loop_streaming_internal( // Emit searching/interpreting event for builtin tools if let Some(event) = - emitter.emit_tool_call_searching(output_index, &item_id, &response_format) + emitter.emit_tool_call_searching(output_index, &item_id, response_format) { emitter.send_event(&event, &tx)?; } @@ -715,7 +721,7 @@ async fn execute_tool_loop_streaming_internal( }; prepare_hosted_dispatch_args( &mut arguments, - &response_format, + response_format, original_request.tools.as_deref().unwrap_or(&[]), original_request.user.as_deref(), ); @@ -736,7 +742,7 @@ async fn execute_tool_loop_streaming_internal( if success { // Emit tool_call.completed let event = - emitter.emit_tool_call_completed(output_index, &item_id, &response_format); + emitter.emit_tool_call_completed(output_index, &item_id, response_format); emitter.send_event(&event, &tx)?; // Build complete item with output @@ -806,8 +812,8 @@ async fn execute_tool_loop_streaming_internal( }, ); - // Use the centralized tool output transformer from MCP crate output type. - let output_item = tool_output.to_response_item(); + let output_item = + openai_bridge::transform_tool_output(&tool_output, response_format); // Record the call in state with transformed output item state.record_call( diff --git a/model_gateway/src/routers/grpc/router.rs b/model_gateway/src/routers/grpc/router.rs index 524da0773..c1cb982b5 100644 --- a/model_gateway/src/routers/grpc/router.rs +++ b/model_gateway/src/routers/grpc/router.rs @@ -150,6 +150,7 @@ impl GrpcRouter { ctx.conversation_item_storage.clone(), ctx.conversation_memory_writer.clone(), mcp_orchestrator.clone(), + ctx.mcp_format_registry.clone(), storage_request_context.clone(), ) }; @@ -343,6 +344,7 @@ impl GrpcRouter { .conversation_memory_writer .clone(), self.harmony_responses_context.mcp_orchestrator.clone(), + self.harmony_responses_context.mcp_format_registry.clone(), smg_data_connector::current_request_context(), ); diff --git a/model_gateway/src/routers/openai/context.rs b/model_gateway/src/routers/openai/context.rs index 9e94ccd24..f40528f48 100644 --- a/model_gateway/src/routers/openai/context.rs +++ b/model_gateway/src/routers/openai/context.rs @@ -14,7 +14,7 @@ use smg_mcp::{McpOrchestrator, McpToolSession}; use super::provider::Provider; use crate::{ config::RouterConfig, memory::MemoryExecutionContext, middleware, - middleware::TenantRequestMeta, worker::Worker, + middleware::TenantRequestMeta, routers::common::openai_bridge, worker::Worker, }; pub struct RequestContext { @@ -47,6 +47,7 @@ pub struct SharedComponents { pub struct ResponsesComponents { pub shared: Arc, pub mcp_orchestrator: Arc, + pub mcp_format_registry: openai_bridge::FormatRegistry, pub response_storage: Arc, pub conversation_storage: Arc, pub conversation_item_storage: Arc, @@ -81,6 +82,13 @@ impl ComponentRefs { } } + pub fn mcp_format_registry(&self) -> Option<&openai_bridge::FormatRegistry> { + match self { + ComponentRefs::Shared(_) => None, + ComponentRefs::Responses(r) => Some(&r.mcp_format_registry), + } + } + pub fn response_storage(&self) -> Option<&Arc> { match self { ComponentRefs::Shared(_) => None, @@ -322,6 +330,7 @@ pub struct StreamingEventContext<'a> { pub original_request: &'a ResponsesRequest, pub previous_response_id: Option<&'a str>, pub session: Option<&'a McpToolSession<'a>>, + pub mcp_format_registry: Option<&'a openai_bridge::FormatRegistry>, } pub type StreamingRequest = OwnedStreamingContext; diff --git a/model_gateway/src/routers/openai/mcp/tool_loop.rs b/model_gateway/src/routers/openai/mcp/tool_loop.rs index 730cdd430..9016536cd 100644 --- a/model_gateway/src/routers/openai/mcp/tool_loop.rs +++ b/model_gateway/src/routers/openai/mcp/tool_loop.rs @@ -20,10 +20,7 @@ use openai_protocol::{ responses::{generate_id, ResponseInput, ResponseTool, ResponsesRequest}, }; use serde_json::{json, to_value, Value}; -use smg_mcp::{ - extract_embedded_openai_responses, mcp_response_item_id, McpServerBinding, McpToolSession, - ResponseFormat, ResponseTransformer, ToolExecutionInput, ToolExecutionResult, -}; +use smg_mcp::{McpServerBinding, McpToolSession, ToolExecutionInput, ToolExecutionResult}; use tokio::sync::mpsc; use tracing::{debug, info, warn}; @@ -34,6 +31,10 @@ use crate::{ common::{ header_utils::ApiProvider, mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS}, + openai_bridge::{ + self, extract_embedded_openai_responses, mcp_response_item_id, FormatRegistry, + ResponseFormat, ResponseTransformer, + }, }, error, }, @@ -171,6 +172,7 @@ fn build_message_from_openai_response(openai_response: Value) -> Option { pub(crate) async fn execute_streaming_tool_calls( pending_calls: Vec, session: &McpToolSession<'_>, + format_registry: &FormatRegistry, tx: &mpsc::UnboundedSender>, state: &mut ToolLoopState, sequence_number: &mut u64, @@ -198,7 +200,8 @@ pub(crate) async fn execute_streaming_tool_calls( &call.arguments_buffer }; - let response_format = session.tool_response_format(&call.name); + let response_format = + openai_bridge::lookup_tool_format(session, format_registry, &call.name); let server_label = session.resolve_tool_server_label(&call.name); let mut arguments: Value = match serde_json::from_str(args_str) { @@ -209,7 +212,7 @@ pub(crate) async fn execute_streaming_tool_calls( let error_output = json!({ "error": &err_str }); let mut mcp_call_item = build_transformed_mcp_call_item( &error_output, - &response_format, + response_format, &call.call_id, &server_label, &call.name, @@ -218,14 +221,14 @@ pub(crate) async fn execute_streaming_tool_calls( if let Some(obj) = mcp_call_item.as_object_mut() { obj.insert( "id".to_string(), - Value::String(stable_streaming_tool_item_id(&call, &response_format)), + Value::String(stable_streaming_tool_item_id(&call, response_format)), ); } if !send_tool_call_completion_events( tx, &call, &mcp_call_item, - &response_format, + response_format, sequence_number, ) { return false; @@ -242,7 +245,7 @@ pub(crate) async fn execute_streaming_tool_calls( } }; - if !send_tool_call_intermediate_event(tx, &call, &response_format, sequence_number) { + if !send_tool_call_intermediate_event(tx, &call, response_format, sequence_number) { return false; } @@ -256,12 +259,7 @@ pub(crate) async fn execute_streaming_tool_calls( // on image_generation) into dispatch args, then forward the request- // level `user` so a downstream MCP server can attribute per-user usage. // Both steps are no-ops for plain MCP function tools. - prepare_hosted_dispatch_args( - &mut arguments, - &response_format, - request_tools, - request_user, - ); + prepare_hosted_dispatch_args(&mut arguments, response_format, request_tools, request_user); // Log the effective (post-merge) args so the log reflects what the // MCP server actually receives, not the pre-merge string from the model. @@ -286,14 +284,18 @@ pub(crate) async fn execute_streaming_tool_calls( ); let output_str = tool_output.output.to_string(); - let mut mcp_call_item = to_value(tool_output.to_response_item()).unwrap_or_else(|e| { + let mut mcp_call_item = to_value(openai_bridge::transform_tool_output( + &tool_output, + response_format, + )) + .unwrap_or_else(|e| { warn!(tool = %call.name, error = %e, "Failed to convert item to Value"); json!({}) }); if let Some(obj) = mcp_call_item.as_object_mut() { obj.insert( "id".to_string(), - Value::String(stable_streaming_tool_item_id(&call, &response_format)), + Value::String(stable_streaming_tool_item_id(&call, response_format)), ); } @@ -301,7 +303,7 @@ pub(crate) async fn execute_streaming_tool_calls( tx, &call, &mcp_call_item, - &response_format, + response_format, sequence_number, ) { return false; @@ -343,7 +345,7 @@ pub(crate) fn prepare_mcp_tools_as_functions(payload: &mut Value, session: &McpT } } - let session_tools = session.build_function_tools_json(); + let session_tools = openai_bridge::function_tools_json(session); let mut tools_json = Vec::with_capacity(retained_tools.len() + session_tools.len()); tools_json.append(&mut retained_tools); tools_json.extend(session_tools); @@ -413,7 +415,7 @@ pub(crate) fn send_mcp_list_tools_events( sequence_number: &mut u64, server_key: &str, ) -> bool { - let tools_item_full = session.build_mcp_list_tools_json(server_label, server_key); + let tools_item_full = openai_bridge::mcp_list_tools_json(session, server_label, server_key); let item_id = tools_item_full .get("id") .and_then(|v| v.as_str()) @@ -497,7 +499,7 @@ pub(crate) fn send_mcp_list_tools_events( fn send_tool_call_intermediate_event( tx: &mpsc::UnboundedSender>, call: &FunctionCallInProgress, - response_format: &ResponseFormat, + response_format: ResponseFormat, sequence_number: &mut u64, ) -> bool { // Determine event type and ID prefix based on response format @@ -538,7 +540,7 @@ fn send_tool_call_completion_events( tx: &mpsc::UnboundedSender>, call: &FunctionCallInProgress, tool_call_item: &Value, - response_format: &ResponseFormat, + response_format: ResponseFormat, sequence_number: &mut u64, ) -> bool { let effective_output_index = call.effective_output_index(); @@ -591,7 +593,7 @@ fn send_tool_call_completion_events( fn stable_streaming_tool_item_id( call: &FunctionCallInProgress, - response_format: &ResponseFormat, + response_format: ResponseFormat, ) -> String { let source_id = call.item_id.as_deref().unwrap_or(call.call_id.as_str()); @@ -619,7 +621,7 @@ fn normalize_tool_item_id_with_prefix(source_id: &str, target_prefix: &str) -> S .unwrap_or_else(|| format!("{target_prefix}{source_id}")) } -fn non_streaming_tool_item_id_source(item_id: &str, response_format: &ResponseFormat) -> String { +fn non_streaming_tool_item_id_source(item_id: &str, response_format: ResponseFormat) -> String { match response_format { ResponseFormat::Passthrough => item_id.to_string(), ResponseFormat::WebSearchCall @@ -667,7 +669,11 @@ pub(crate) fn inject_mcp_metadata_streaming( let mut prefix = Vec::with_capacity(list_tools_bindings.len() + state.mcp_call_items.len()); for (server_label, server_key) in &list_tools_bindings { if !session.is_internal_server_label(server_label) { - prefix.push(session.build_mcp_list_tools_json(server_label, server_key)); + prefix.push(openai_bridge::mcp_list_tools_json( + session, + server_label, + server_key, + )); } } prefix.extend( @@ -682,7 +688,11 @@ pub(crate) fn inject_mcp_metadata_streaming( let mut output_items = Vec::new(); for (server_label, server_key) in &list_tools_bindings { if !session.is_internal_server_label(server_label) { - output_items.push(session.build_mcp_list_tools_json(server_label, server_key)); + output_items.push(openai_bridge::mcp_list_tools_json( + session, + server_label, + server_key, + )); } } // Use stored transformed items (no reconstruction needed) @@ -776,7 +786,11 @@ fn approval_prefix_items( let mut prefix = Vec::with_capacity(list_tools_bindings.len() + state.mcp_call_items.len() + 1); for (list_server_label, server_key) in list_tools_bindings { if !session.is_internal_server_label(list_server_label) { - prefix.push(session.build_mcp_list_tools_json(list_server_label, server_key)); + prefix.push(openai_bridge::mcp_list_tools_json( + session, + list_server_label, + server_key, + )); } } prefix.extend( @@ -796,6 +810,7 @@ pub(crate) struct ToolLoopExecutionContext<'a> { pub original_body: &'a ResponsesRequest, pub existing_mcp_list_tools_labels: &'a [String], pub session: &'a McpToolSession<'a>, + pub format_registry: &'a FormatRegistry, } /// Execute the tool calling loop @@ -811,6 +826,7 @@ pub(crate) async fn execute_tool_loop( original_body, existing_mcp_list_tools_labels, session, + format_registry, } = tool_loop_ctx; let mut state = ToolLoopState::new( @@ -897,14 +913,15 @@ pub(crate) async fn execute_tool_loop( Err(e) => { warn!(tool = %call.name, error = %e, "Failed to parse tool arguments as JSON"); let error_output = format!("Invalid tool arguments: {e}"); - let response_format = session.tool_response_format(&call.name); + let response_format = + openai_bridge::lookup_tool_format(session, format_registry, &call.name); let server_label = session.resolve_tool_server_label(&call.name); let tool_item_id = - non_streaming_tool_item_id_source(&call.item_id, &response_format); + non_streaming_tool_item_id_source(&call.item_id, response_format); let error_json = json!({ "error": &error_output }); let transformed_item = build_transformed_mcp_call_item( &error_json, - &response_format, + response_format, &tool_item_id, &server_label, &call.name, @@ -939,10 +956,11 @@ pub(crate) async fn execute_tool_loop( // and forward the request-level `user` so a downstream MCP server // can attribute per-user usage. Both steps are no-ops for plain // MCP function tools (Passthrough format). - let response_format = session.tool_response_format(&call.name); + let response_format = + openai_bridge::lookup_tool_format(session, format_registry, &call.name); prepare_hosted_dispatch_args( &mut arguments, - &response_format, + response_format, original_body.tools.as_deref().unwrap_or(&[]), original_body.user.as_deref(), ); @@ -966,7 +984,7 @@ pub(crate) async fn execute_tool_loop( .await; let server_label = session.resolve_tool_server_label(&call.name); - let tool_item_id = non_streaming_tool_item_id_source(&call.item_id, &response_format); + let tool_item_id = non_streaming_tool_item_id_source(&call.item_id, response_format); let approval_request_id = approval_request_item_id_source(&call.item_id); let tool_output = match tool_result { @@ -1006,11 +1024,15 @@ pub(crate) async fn execute_tool_loop( let output_str = tool_output.output.to_string(); let transformed_item = build_transformed_mcp_call_item( &tool_output.output, - &response_format, + response_format, &tool_item_id, &server_label, &call.name, - &call.arguments, + // Use the post-merge string so the client-visible item describes + // what the router actually dispatched (e.g. hosted-tool overrides + // like image-generation `size`/`quality` merged in above), not + // the pre-merge arguments the model emitted. + &effective_arguments, ); state.record_call( @@ -1112,7 +1134,11 @@ fn build_incomplete_response( ); for (server_label, server_key) in &list_tools_bindings { if !session.is_internal_server_label(server_label) { - prefix.push(session.build_mcp_list_tools_json(server_label, server_key)); + prefix.push(openai_bridge::mcp_list_tools_json( + session, + server_label, + server_key, + )); } } prefix.extend( @@ -1207,7 +1233,7 @@ fn build_mcp_approval_request_item( /// Returns the result as a JSON Value for SSE event streaming. fn build_transformed_mcp_call_item( output: &Value, - response_format: &ResponseFormat, + response_format: ResponseFormat, tool_item_id: &str, server_label: &str, tool_name: &str, @@ -1280,7 +1306,7 @@ mod tests { use serde_json::json; use smg_mcp::{ BuiltinToolType, McpConfig, McpOrchestrator, McpServerBinding, McpServerConfig, - McpToolSession, McpTransport, ResponseFormat, Tool, ToolEntry, + McpToolSession, McpTransport, Tool, ToolEntry, }; use tokio::sync::mpsc; @@ -1289,6 +1315,7 @@ mod tests { is_internal_mcp_response_item, mcp_list_tools_bindings_to_emit, ResponseInput, ToolLoopState, }; + use crate::routers::common::openai_bridge::ResponseFormat; fn test_tool(name: &str) -> Tool { let mut schema = serde_json::Map::new(); @@ -1315,7 +1342,7 @@ mod tests { { "url": "https://example.com" } ] }), - &ResponseFormat::WebSearchCall, + ResponseFormat::WebSearchCall, "call_123", "internal-label", "brave_web_search", @@ -1656,7 +1683,7 @@ mod tests { &tx, &call, &tool_call_item, - &ResponseFormat::ImageGenerationCall, + ResponseFormat::ImageGenerationCall, &mut sequence_number, ); assert!(ok, "send_tool_call_completion_events should not disconnect"); @@ -1713,7 +1740,7 @@ mod tests { &tx, &call, &tool_call_item, - &ResponseFormat::WebSearchCall, + ResponseFormat::WebSearchCall, &mut sequence_number, ); assert!(ok); @@ -1769,7 +1796,7 @@ mod tests { &tx, &call, &tool_call_item, - &ResponseFormat::CodeInterpreterCall, + ResponseFormat::CodeInterpreterCall, &mut sequence_number, ); assert!(ok); @@ -1837,7 +1864,7 @@ mod tests { &tx, &call, &tool_call_item, - &ResponseFormat::FileSearchCall, + ResponseFormat::FileSearchCall, &mut sequence_number, ); assert!(ok); diff --git a/model_gateway/src/routers/openai/responses/non_streaming.rs b/model_gateway/src/routers/openai/responses/non_streaming.rs index 8a9959c09..44ae40224 100644 --- a/model_gateway/src/routers/openai/responses/non_streaming.rs +++ b/model_gateway/src/routers/openai/responses/non_streaming.rs @@ -16,6 +16,7 @@ use crate::routers::{ common::{ header_utils::{extract_forwardable_request_headers, ApiProvider}, mcp_utils::ensure_request_mcp_client, + openai_bridge, persistence_utils::persist_conversation_items, }, error, @@ -62,9 +63,19 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response } }; + // The format registry is the router-side source of truth for MCP + // builtin/alias format resolution; falling back to a default would + // silently mis-route hosted tools instead of failing fast. + let mcp_format_registry = match ctx.components.mcp_format_registry() { + Some(r) => r.clone(), + None => { + return error::internal_error("internal_error", "MCP format registry required"); + } + }; + // Check for MCP tools and create session if needed let mcp_servers = if let Some(tools) = original_body.tools.as_deref() { - ensure_request_mcp_client(mcp_orchestrator, tools).await + ensure_request_mcp_client(mcp_orchestrator, &mcp_format_registry, tools).await } else { None }; @@ -84,7 +95,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response forwarded_headers, ); if let Some(tools) = original_body.tools.as_deref() { - session.configure_response_tools_approval(tools); + openai_bridge::configure_response_tools_approval(&mut session, tools); } prepare_mcp_tools_as_functions(&mut payload, &session); @@ -98,6 +109,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response original_body, existing_mcp_list_tools_labels: &existing_mcp_list_tools_labels, session: &session, + format_registry: &mcp_format_registry, }, ) .await diff --git a/model_gateway/src/routers/openai/responses/streaming.rs b/model_gateway/src/routers/openai/responses/streaming.rs index e22e36cf8..2b6a204a3 100644 --- a/model_gateway/src/routers/openai/responses/streaming.rs +++ b/model_gateway/src/routers/openai/responses/streaming.rs @@ -25,9 +25,7 @@ use openai_protocol::{ responses::{ResponseTool, ResponsesRequest}, }; use serde_json::{json, Value}; -use smg_mcp::{ - mcp_response_item_id, McpOrchestrator, McpServerBinding, McpToolSession, ResponseFormat, -}; +use smg_mcp::{McpOrchestrator, McpServerBinding, McpToolSession}; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::warn; @@ -40,6 +38,7 @@ use super::{ rewrite_streaming_block, }, }; +use crate::routers::common::openai_bridge::{self, mcp_response_item_id, ResponseFormat}; const SSE_DONE: &str = "data: [DONE]\n\n"; use crate::{ @@ -147,16 +146,15 @@ pub(super) fn apply_event_transformations_inplace( if let Some(session) = ctx.session.filter(|s| s.has_exposed_tool(&tool_name)) { - let response_format = session.tool_response_format(&tool_name); + let response_format = ctx + .mcp_format_registry + .map(|reg| { + openai_bridge::lookup_tool_format(session, reg, &tool_name) + }) + .unwrap_or(ResponseFormat::Passthrough); - // Determine item type and ID prefix based on response_format - let (new_type, id_prefix) = match response_format { - ResponseFormat::WebSearchCall => (ItemType::WEB_SEARCH_CALL, "ws_"), - ResponseFormat::ImageGenerationCall => { - (ItemType::IMAGE_GENERATION_CALL, "ig_") - } - _ => (ItemType::MCP_CALL, "mcp_"), - }; + let d = openai_bridge::descriptor(response_format); + let (new_type, id_prefix) = (d.type_str, d.id_prefix); item["type"] = json!(new_type); if new_type == ItemType::MCP_CALL { @@ -674,6 +672,7 @@ pub(super) fn handle_streaming_with_tool_interception( headers: Option<&HeaderMap>, req: StreamingRequest, orchestrator: &Arc, + format_registry: openai_bridge::FormatRegistry, mcp_servers: Vec, ) -> Response { let payload = req.payload; @@ -730,6 +729,7 @@ pub(super) fn handle_streaming_with_tool_interception( original_request: &original_request, previous_response_id: previous_response_id.as_deref(), session: Some(&session), + mcp_format_registry: Some(&format_registry), }; let provider = ApiProvider::from_url(&url_clone); let auth_header = @@ -1016,6 +1016,7 @@ pub(super) fn handle_streaming_with_tool_interception( if !execute_streaming_tool_calls( pending_calls, &session, + &format_registry, &tx, &mut state, &mut sequence_number, @@ -1085,10 +1086,18 @@ pub async fn handle_streaming_response(ctx: RequestContext) -> Response { return error::internal_error("internal_error", "MCP orchestrator required"); } }; + // Same fail-fast contract as the non-streaming path: a missing format + // registry means MCP routing decisions would be silently wrong. + let mcp_format_registry = match ctx.components.mcp_format_registry() { + Some(r) => r.clone(), + None => { + return error::internal_error("internal_error", "MCP format registry required"); + } + }; // Check for MCP tools and create request context if needed let mcp_servers = if let Some(tools) = original_body.tools.as_deref() { - ensure_request_mcp_client(&mcp_orchestrator, tools).await + ensure_request_mcp_client(&mcp_orchestrator, &mcp_format_registry, tools).await } else { None }; @@ -1111,6 +1120,7 @@ pub async fn handle_streaming_response(ctx: RequestContext) -> Response { headers.as_ref(), req, &mcp_orchestrator, + mcp_format_registry, mcp_servers, ) } diff --git a/model_gateway/src/routers/openai/responses/utils.rs b/model_gateway/src/routers/openai/responses/utils.rs index 03a2b24a2..61a3588dc 100644 --- a/model_gateway/src/routers/openai/responses/utils.rs +++ b/model_gateway/src/routers/openai/responses/utils.rs @@ -12,7 +12,7 @@ use smg_mcp::McpToolSession; use tracing::warn; use super::common::parse_sse_block; -use crate::routers::common::mcp_utils::collect_user_function_names; +use crate::routers::common::{mcp_utils::collect_user_function_names, openai_bridge}; /// Check if a JSON value is missing, null, or an empty string fn is_missing_or_empty(value: Option<&Value>) -> bool { @@ -312,8 +312,9 @@ fn restore_client_tool_view( } _ => { if let Some(value) = response_tool_to_value(original_tool) { - let should_hide = session - .is_some_and(|s| s.should_hide_tool_json(&value, user_function_names)); + let should_hide = session.is_some_and(|s| { + openai_bridge::should_hide_tool_json(s, &value, user_function_names) + }); if !should_hide { restored_tools.push(value); } @@ -376,7 +377,7 @@ fn strip_internal_mcp_tools( }; tools.retain(|tool| { - !session.is_some_and(|s| s.should_hide_tool_json(tool, user_function_names)) + !session.is_some_and(|s| openai_bridge::should_hide_tool_json(s, tool, user_function_names)) }); } @@ -394,7 +395,9 @@ fn strip_internal_mcp_output_items( }; output.retain(|item| { - !session.is_some_and(|s| s.should_hide_output_item_json(item, user_function_names)) + !session.is_some_and(|s| { + openai_bridge::should_hide_output_item_json(s, item, user_function_names) + }) }); } diff --git a/model_gateway/src/routers/openai/router.rs b/model_gateway/src/routers/openai/router.rs index 0bda10e99..33958f54c 100644 --- a/model_gateway/src/routers/openai/router.rs +++ b/model_gateway/src/routers/openai/router.rs @@ -96,6 +96,7 @@ impl OpenAIRouter { let responses_components = Arc::new(ResponsesComponents { shared: Arc::clone(&shared_components), mcp_orchestrator: mcp_orchestrator.clone(), + mcp_format_registry: ctx.mcp_format_registry.clone(), response_storage: ctx.response_storage.clone(), conversation_storage: ctx.conversation_storage.clone(), conversation_item_storage: ctx.conversation_item_storage.clone(), diff --git a/model_gateway/src/service_discovery.rs b/model_gateway/src/service_discovery.rs index fc35a7997..5d389dc83 100644 --- a/model_gateway/src/service_discovery.rs +++ b/model_gateway/src/service_discovery.rs @@ -1155,7 +1155,7 @@ mod tests { }; use super::*; - use crate::routers::grpc::multimodal::MultimodalConfigRegistry; + use crate::routers::{common::openai_bridge, grpc::multimodal::MultimodalConfigRegistry}; fn create_k8s_pod( name: Option<&str>, @@ -1278,6 +1278,7 @@ mod tests { worker_job_queue: worker_job_queue.clone(), workflow_engines: Arc::new(std::sync::OnceLock::new()), mcp_orchestrator: Arc::new(std::sync::OnceLock::new()), + mcp_format_registry: openai_bridge::FormatRegistry::new(), tokenizer_registry: Arc::new(llm_tokenizer::registry::TokenizerRegistry::new()), multimodal_config_registry: Arc::new(MultimodalConfigRegistry::new()), skill_service: None, diff --git a/model_gateway/src/workflow/mcp_registration.rs b/model_gateway/src/workflow/mcp_registration.rs index fbc966241..f8af6fc9f 100644 --- a/model_gateway/src/workflow/mcp_registration.rs +++ b/model_gateway/src/workflow/mcp_registration.rs @@ -77,6 +77,10 @@ impl StepExecutor for ConnectMcpServerStep { ), })?; + app_context + .mcp_format_registry + .populate_from_server_config(&config_request.config); + // Update active MCP servers metric Metrics::set_mcp_servers_active(mcp_orchestrator.list_servers().len()); diff --git a/model_gateway/tests/mcp_test.rs b/model_gateway/tests/mcp_test.rs index 63a817ad7..0aa84eb81 100644 --- a/model_gateway/tests/mcp_test.rs +++ b/model_gateway/tests/mcp_test.rs @@ -14,11 +14,11 @@ use std::collections::HashMap; use common::mock_mcp_server::{MockMCPServer, MockSearchResponseMCPServer, MockSearchResponseMode}; use openai_protocol::responses::{ResponseOutputItem, WebSearchAction}; use serde_json::json; +use smg::routers::common::openai_bridge::{ResponseFormat, ResponseTransformer}; use smg_mcp::{ core::config::{ResponseFormatConfig, ToolConfig}, - error::McpError, - ApprovalMode, McpConfig, McpOrchestrator, McpServerConfig, McpTransport, TenantContext, - ToolCallResult, + McpConfig, McpOrchestrator, McpServerBinding, McpServerConfig, McpToolSession, McpTransport, + ToolExecutionInput, }; /// Create a new mock server for testing (each test gets its own) @@ -243,48 +243,37 @@ async fn test_tool_execution_with_mock() { let manager = McpOrchestrator::new(config).await.unwrap(); - let request_ctx = manager.create_request_context( + let session = McpToolSession::new( + &manager, + vec![McpServerBinding { + label: "mock_server".to_string(), + server_key: "mock_server".to_string(), + allowed_tools: None, + }], "test-request-1", - TenantContext::default(), - ApprovalMode::PolicyOnly, ); - let result = manager - .call_tool( - "mock_server", - "brave_web_search", - json!({ + let output = session + .execute_tool(ToolExecutionInput { + call_id: "call-1".to_string(), + tool_name: "brave_web_search".to_string(), + arguments: json!({ "query": "rust programming", "count": 1 }), - "mock_server", - &request_ctx, - ) + }) .await; + assert!(!output.is_error, "Tool execution should succeed"); assert!( - result.is_ok(), - "Tool execution should succeed with mock server" + output + .output + .to_string() + .contains("Mock search results for: rust programming"), + "Output should contain mock search results, got: {}", + output.output ); - let response = result.unwrap(); - match response { - ToolCallResult::Success(output_item) => { - // Verify the response is an MCP call with output - match output_item { - ResponseOutputItem::McpCall { output, status, .. } => { - assert_eq!(status, "completed"); - assert!( - output.contains("Mock search results for: rust programming"), - "Output should contain mock search results" - ); - } - _ => panic!("Expected McpCall output item"), - } - } - ToolCallResult::PendingApproval(_) => panic!("Expected Success result"), - } - manager.shutdown().await; } @@ -296,7 +285,7 @@ async fn test_web_search_transform_handles_openai_search_response_with_mock() { "brave_web_search".to_string(), ToolConfig { alias: None, - response_format: ResponseFormatConfig::WebSearchCall, + response_format: Some(ResponseFormatConfig::WebSearchCall), arg_mapping: None, }, ); @@ -325,42 +314,49 @@ async fn test_web_search_transform_handles_openai_search_response_with_mock() { let manager = McpOrchestrator::new(config).await.unwrap(); - let request_ctx = manager.create_request_context( + let session = McpToolSession::new( + &manager, + vec![McpServerBinding { + label: "openai_search_server".to_string(), + server_key: "openai_search_server".to_string(), + allowed_tools: None, + }], "test-request-openai-search", - TenantContext::default(), - ApprovalMode::PolicyOnly, ); - let result = manager - .call_tool( - "openai_search_server", - "brave_web_search", - json!({ - "query": "rust openai search" - }), - "openai_search_server", - &request_ctx, - ) + let output = session + .execute_tool(ToolExecutionInput { + call_id: "call-1".to_string(), + tool_name: "brave_web_search".to_string(), + arguments: json!({ "query": "rust openai search" }), + }) .await; - assert!(result.is_ok(), "Tool execution should succeed"); + assert!(!output.is_error, "Tool execution should succeed"); - match result.unwrap() { - ToolCallResult::Success(ResponseOutputItem::WebSearchCall { action, .. }) => match action { + // The session returns the raw `output` Value from the MCP call. Re-transform + // with WebSearchCall format to verify serialization (end-to-end source + // extraction is covered by the gateway bridge's own tests). + let transformed = ResponseTransformer::transform( + &output.output, + ResponseFormat::WebSearchCall, + "test-request-openai-search", + "openai_search_server", + "brave_web_search", + "{\"query\":\"rust openai search\"}", + ); + match transformed { + ResponseOutputItem::WebSearchCall { action, .. } => match action { WebSearchAction::Search { query, queries: _, - sources, + sources: _, } => { assert_eq!(query, Some("rust openai search".to_string())); - assert_eq!(sources.len(), 1); - assert_eq!(sources[0].source_type, "url"); - assert_eq!(sources[0].url, "https://example.com/openai-result"); } _ => panic!("Expected Search action"), }, - ToolCallResult::Success(other) => panic!("Expected WebSearchCall, got {other:?}"), - ToolCallResult::PendingApproval(_) => panic!("Expected Success result"), + other => panic!("Expected WebSearchCall, got {other:?}"), } manager.shutdown().await; @@ -374,7 +370,7 @@ async fn test_web_search_transform_sets_action_query_for_brave_search_with_mock( "brave_web_search".to_string(), ToolConfig { alias: None, - response_format: ResponseFormatConfig::WebSearchCall, + response_format: Some(ResponseFormatConfig::WebSearchCall), arg_mapping: None, }, ); @@ -403,28 +399,36 @@ async fn test_web_search_transform_sets_action_query_for_brave_search_with_mock( let manager = McpOrchestrator::new(config).await.unwrap(); - let request_ctx = manager.create_request_context( + let session = McpToolSession::new( + &manager, + vec![McpServerBinding { + label: "brave_response_server".to_string(), + server_key: "brave_response_server".to_string(), + allowed_tools: None, + }], "test-request-brave", - TenantContext::default(), - ApprovalMode::PolicyOnly, ); - let result = manager - .call_tool( - "brave_response_server", - "brave_web_search", - json!({ - "query": "rust brave query" - }), - "brave_response_server", - &request_ctx, - ) + let output = session + .execute_tool(ToolExecutionInput { + call_id: "call-1".to_string(), + tool_name: "brave_web_search".to_string(), + arguments: json!({ "query": "rust brave query" }), + }) .await; - assert!(result.is_ok(), "Tool execution should succeed"); + assert!(!output.is_error, "Tool execution should succeed"); - match result.unwrap() { - ToolCallResult::Success(ResponseOutputItem::WebSearchCall { action, .. }) => match action { + let transformed = ResponseTransformer::transform( + &output.output, + ResponseFormat::WebSearchCall, + "test-request-brave", + "brave_response_server", + "brave_web_search", + "{\"query\":\"rust brave query\"}", + ); + match transformed { + ResponseOutputItem::WebSearchCall { action, .. } => match action { WebSearchAction::Search { query, queries: _, @@ -434,8 +438,7 @@ async fn test_web_search_transform_sets_action_query_for_brave_search_with_mock( } _ => panic!("Expected Search action"), }, - ToolCallResult::Success(other) => panic!("Expected WebSearchCall, got {other:?}"), - ToolCallResult::PendingApproval(_) => panic!("Expected Success result"), + other => panic!("Expected WebSearchCall, got {other:?}"), } manager.shutdown().await; } @@ -468,38 +471,34 @@ async fn test_concurrent_tool_execution() { let manager = McpOrchestrator::new(config).await.unwrap(); - let request_ctx = manager.create_request_context( + let session = McpToolSession::new( + &manager, + vec![McpServerBinding { + label: "mock_server".to_string(), + server_key: "mock_server".to_string(), + allowed_tools: None, + }], "test-concurrent", - TenantContext::default(), - ApprovalMode::PolicyOnly, ); - // Execute tools sequentially (true concurrent execution would require Arc) let tool_calls = vec![ ("brave_web_search", json!({"query": "test1"})), ("brave_local_search", json!({"query": "test2"})), ]; for (tool_name, args) in tool_calls { - let result = manager - .call_tool("mock_server", tool_name, args, "mock_server", &request_ctx) + let output = session + .execute_tool(ToolExecutionInput { + call_id: format!("call-{tool_name}"), + tool_name: tool_name.to_string(), + arguments: args, + }) .await; - - assert!(result.is_ok(), "Tool {tool_name} should succeed"); - let response = result.unwrap(); - match response { - ToolCallResult::Success(output_item) => { - // Verify the response is an MCP call with output - match output_item { - ResponseOutputItem::McpCall { status, output, .. } => { - assert_eq!(status, "completed"); - assert!(!output.is_empty(), "Should have output content"); - } - _ => panic!("Expected McpCall output item"), - } - } - ToolCallResult::PendingApproval(_) => panic!("Expected Success result"), - } + assert!(!output.is_error, "Tool {tool_name} should succeed"); + assert!( + !output.output.to_string().is_empty(), + "Should have output content" + ); } manager.shutdown().await; @@ -535,31 +534,28 @@ async fn test_tool_execution_errors() { let manager = McpOrchestrator::new(config).await.unwrap(); - let request_ctx = manager.create_request_context( + let session = McpToolSession::new( + &manager, + vec![McpServerBinding { + label: "mock_server".to_string(), + server_key: "mock_server".to_string(), + allowed_tools: None, + }], "test-error", - TenantContext::default(), - ApprovalMode::PolicyOnly, ); - // Try to call unknown tool - let result = manager - .call_tool( - "mock_server", - "unknown_tool", - json!({}), - "mock_server", - &request_ctx, - ) + let output = session + .execute_tool(ToolExecutionInput { + call_id: "call-1".to_string(), + tool_name: "unknown_tool".to_string(), + arguments: json!({}), + }) .await; - assert!(result.is_err(), "Should fail for unknown tool"); - - match result.unwrap_err() { - McpError::ToolNotFound(name) => { - // Error message now includes qualified name (server_key:tool_name) - assert_eq!(name, "mock_server:unknown_tool"); - } - _ => panic!("Expected ToolNotFound error"), - } + assert!( + output.is_error, + "Unknown tool should produce an error output" + ); + assert_eq!(output.tool_name, "unknown_tool"); manager.shutdown().await; } @@ -792,40 +788,32 @@ async fn test_complete_workflow() { assert!(!manager.has_tool("integration_test", "nonexistent_tool")); // 6. Execute a tool - let request_ctx = manager.create_request_context( + let session = McpToolSession::new( + &manager, + vec![McpServerBinding { + label: "integration_test".to_string(), + server_key: "integration_test".to_string(), + allowed_tools: None, + }], "test-workflow", - TenantContext::default(), - ApprovalMode::PolicyOnly, ); - let result = manager - .call_tool( - "integration_test", - "brave_web_search", - json!({ + let output = session + .execute_tool(ToolExecutionInput { + call_id: "call-1".to_string(), + tool_name: "brave_web_search".to_string(), + arguments: json!({ "query": "SGLang router MCP integration", "count": 1 }), - "integration_test", - &request_ctx, - ) + }) .await; - assert!(result.is_ok(), "Tool execution should succeed"); - let response = result.unwrap(); - match response { - ToolCallResult::Success(output_item) => { - // Verify the response is an MCP call with output - match output_item { - ResponseOutputItem::McpCall { status, output, .. } => { - assert_eq!(status, "completed"); - assert!(!output.is_empty(), "Should return output content"); - } - _ => panic!("Expected McpCall output item"), - } - } - ToolCallResult::PendingApproval(_) => panic!("Expected Success result"), - } + assert!(!output.is_error, "Tool execution should succeed"); + assert!( + !output.output.to_string().is_empty(), + "Should return output content" + ); // 7. Clean shutdown manager.shutdown().await;