diff --git a/model_gateway/src/routers/openai/context.rs b/model_gateway/src/routers/openai/context.rs index 9e94ccd24..aa8e1c4a1 100644 --- a/model_gateway/src/routers/openai/context.rs +++ b/model_gateway/src/routers/openai/context.rs @@ -11,7 +11,10 @@ use smg_data_connector::{ }; use smg_mcp::{McpOrchestrator, McpToolSession}; -use super::provider::Provider; +use super::{ + provider::Provider, + stateful_tools::{SharedStatefulToolBootstrapper, StatefulToolBootstrapState}, +}; use crate::{ config::RouterConfig, memory::MemoryExecutionContext, middleware, middleware::TenantRequestMeta, worker::Worker, @@ -47,6 +50,7 @@ pub struct SharedComponents { pub struct ResponsesComponents { pub shared: Arc, pub mcp_orchestrator: Arc, + pub stateful_tool_bootstrapper: SharedStatefulToolBootstrapper, pub response_storage: Arc, pub conversation_storage: Arc, pub conversation_item_storage: Arc, @@ -88,6 +92,13 @@ impl ComponentRefs { } } + pub fn stateful_tool_bootstrapper(&self) -> Option<&SharedStatefulToolBootstrapper> { + match self { + ComponentRefs::Shared(_) => None, + ComponentRefs::Responses(r) => Some(&r.stateful_tool_bootstrapper), + } + } + pub fn conversation_storage(&self) -> Option<&Arc> { match self { ComponentRefs::Shared(_) => None, @@ -130,8 +141,10 @@ pub struct PayloadState { #[derive(Default)] pub struct ResponsesPayloadState { + pub client_request: Option>, pub previous_response_id: Option, pub existing_mcp_list_tools_labels: Vec, + pub stateful_tool_bootstrap: StatefulToolBootstrapState, } impl RequestContext { @@ -265,9 +278,11 @@ pub struct StorageHandles { pub struct OwnedStreamingContext { pub url: String, pub payload: Value, - pub original_body: ResponsesRequest, + pub request_body: Arc, + pub client_body: Arc, pub previous_response_id: Option, pub existing_mcp_list_tools_labels: Vec, + pub stateful_tool_bootstrap: StatefulToolBootstrapState, pub storage: StorageHandles, } @@ -275,10 +290,12 @@ impl RequestContext { pub fn into_streaming_context(mut self) -> Result { let payload_state = self.take_payload().ok_or("Payload not prepared")?; let responses_payload_state = self.take_responses_payload().unwrap_or_default(); - let original_body = self - .responses_request() - .ok_or("Expected responses request")? - .clone(); + let request_body = self + .responses_request_arc() + .ok_or("Expected responses request")?; + let client_body = responses_payload_state + .client_request + .unwrap_or_else(|| Arc::clone(&request_body)); let response = self .components .response_storage() @@ -303,9 +320,11 @@ impl RequestContext { Ok(OwnedStreamingContext { url: payload_state.url, payload: payload_state.json, - original_body, + request_body, + client_body, previous_response_id: responses_payload_state.previous_response_id, existing_mcp_list_tools_labels: responses_payload_state.existing_mcp_list_tools_labels, + stateful_tool_bootstrap: responses_payload_state.stateful_tool_bootstrap, storage: StorageHandles { response, conversation, diff --git a/model_gateway/src/routers/openai/mcp/tool_loop.rs b/model_gateway/src/routers/openai/mcp/tool_loop.rs index 730cdd430..036598087 100644 --- a/model_gateway/src/routers/openai/mcp/tool_loop.rs +++ b/model_gateway/src/routers/openai/mcp/tool_loop.rs @@ -36,6 +36,7 @@ use crate::{ mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS}, }, error, + openai::stateful_tools::StatefulToolBootstrapState, }, }; @@ -49,6 +50,8 @@ pub(crate) struct ToolLoopState { pub conversation_history: Vec, /// Original user input (preserved for building resume payloads) pub original_input: ResponseInput, + /// Request-scoped prepared state for hosted/stateful tools. + pub stateful_tool_bootstrap: StatefulToolBootstrapState, /// MCP bindings already represented by historical `mcp_list_tools` items. pub existing_mcp_list_tools_labels: HashSet, /// Transformed output items (mcp_call, web_search_call, etc.) - stored to avoid reconstruction @@ -56,7 +59,20 @@ pub(crate) struct ToolLoopState { } impl ToolLoopState { + #[cfg(test)] pub fn new(original_input: ResponseInput, prior_mcp_list_tools_labels: Vec) -> Self { + Self::new_with_bootstrap( + original_input, + prior_mcp_list_tools_labels, + StatefulToolBootstrapState::default(), + ) + } + + pub fn new_with_bootstrap( + original_input: ResponseInput, + prior_mcp_list_tools_labels: Vec, + stateful_tool_bootstrap: StatefulToolBootstrapState, + ) -> Self { let known_labels = prior_mcp_list_tools_labels .into_iter() .collect::>(); @@ -66,6 +82,7 @@ impl ToolLoopState { total_calls: 0, conversation_history: Vec::new(), original_input, + stateful_tool_bootstrap, existing_mcp_list_tools_labels: known_labels, mcp_call_items: Vec::new(), } @@ -795,6 +812,7 @@ fn approval_prefix_items( pub(crate) struct ToolLoopExecutionContext<'a> { pub original_body: &'a ResponsesRequest, pub existing_mcp_list_tools_labels: &'a [String], + pub stateful_tool_bootstrap: &'a StatefulToolBootstrapState, pub session: &'a McpToolSession<'a>, } @@ -810,12 +828,14 @@ pub(crate) async fn execute_tool_loop( let ToolLoopExecutionContext { original_body, existing_mcp_list_tools_labels, + stateful_tool_bootstrap, session, } = tool_loop_ctx; - let mut state = ToolLoopState::new( + let mut state = ToolLoopState::new_with_bootstrap( original_body.input.clone(), existing_mcp_list_tools_labels.to_vec(), + stateful_tool_bootstrap.clone(), ); let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize); let base_payload = initial_payload.clone(); @@ -823,8 +843,10 @@ pub(crate) async fn execute_tool_loop( let mut current_payload = initial_payload; info!( - "Starting tool loop: max_tool_calls={:?}, max_iterations={}", - max_tool_calls, DEFAULT_MAX_ITERATIONS + "Starting tool loop: max_tool_calls={:?}, max_iterations={}, prepared_stateful_tools={}", + max_tool_calls, + DEFAULT_MAX_ITERATIONS, + state.stateful_tool_bootstrap.prepared_tools.len() ); let provider = ApiProvider::from_url(url); let auth_header = provider.extract_auth_header(headers, worker_api_key); diff --git a/model_gateway/src/routers/openai/mod.rs b/model_gateway/src/routers/openai/mod.rs index e8361015a..30921ee04 100644 --- a/model_gateway/src/routers/openai/mod.rs +++ b/model_gateway/src/routers/openai/mod.rs @@ -15,5 +15,6 @@ mod provider; pub mod realtime; pub mod responses; mod router; +pub mod stateful_tools; pub use router::OpenAIRouter; diff --git a/model_gateway/src/routers/openai/responses/history.rs b/model_gateway/src/routers/openai/responses/history.rs index b457e1c2c..c577aaac3 100644 --- a/model_gateway/src/routers/openai/responses/history.rs +++ b/model_gateway/src/routers/openai/responses/history.rs @@ -249,6 +249,7 @@ fn deserialize_items_from_array(array: &Value) -> Vec { .map_err(|e| warn!("Failed to deserialize item: {}. Item: {}", e, item)) .ok() }) + .filter(|item| !matches!(item, ResponseInputOutputItem::Reasoning { .. })) .collect() }) .unwrap_or_default() @@ -321,9 +322,10 @@ pub(crate) fn inject_memory_context( #[cfg(test)] mod tests { - use openai_protocol::responses::{ResponseInput, ResponsesRequest}; + use openai_protocol::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest}; + use serde_json::json; - use super::inject_memory_context; + use super::{deserialize_items_from_array, inject_memory_context}; use crate::routers::common::header_utils::{ ConversationMemoryConfig, LongTermMemoryConfig, ShortTermMemoryConfig, }; @@ -357,4 +359,33 @@ mod tests { } } } + + #[test] + fn deserialize_items_from_array_drops_reasoning_replay_items() { + let reasoning = json!({ + "type": "reasoning", + "id": "rs_1", + "summary": [] + }); + assert!( + serde_json::from_value::(reasoning.clone()).is_ok(), + "reasoning fixture must deserialize so this test validates filtering logic" + ); + + let items = json!([ + { + "type": "message", + "id": "msg_1", + "role": "user", + "content": [{ "type": "input_text", "text": "hello" }], + "status": "completed" + }, + reasoning + ]); + + let parsed = deserialize_items_from_array(&items); + + assert_eq!(parsed.len(), 1); + assert!(matches!(parsed[0], ResponseInputOutputItem::Message { .. })); + } } diff --git a/model_gateway/src/routers/openai/responses/non_streaming.rs b/model_gateway/src/routers/openai/responses/non_streaming.rs index 8a9959c09..154dfe2e2 100644 --- a/model_gateway/src/routers/openai/responses/non_streaming.rs +++ b/model_gateway/src/routers/openai/responses/non_streaming.rs @@ -11,7 +11,9 @@ use serde_json::Value; use smg_mcp::McpToolSession; use tracing::warn; -use super::utils::{patch_response_with_request_metadata, restore_original_tools}; +use super::utils::{ + build_persistence_request_body, patch_response_with_request_metadata, restore_original_tools, +}; use crate::routers::{ common::{ header_utils::{extract_forwardable_request_headers, ApiProvider}, @@ -39,16 +41,19 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response url, } = payload_state; let ResponsesPayloadState { + client_request, previous_response_id, existing_mcp_list_tools_labels, + stateful_tool_bootstrap, } = ctx.take_responses_payload().unwrap_or_default(); - let original_body = match ctx.responses_request() { + let request_body = match ctx.responses_request() { Some(r) => r, None => { return error::internal_error("internal_error", "Expected responses request"); } }; + let client_body = client_request.as_deref().unwrap_or(request_body); let worker = match ctx.worker() { Some(w) => w.clone(), None => { @@ -63,7 +68,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response }; // Check for MCP tools and create session if needed - let mcp_servers = if let Some(tools) = original_body.tools.as_deref() { + let mcp_servers = if let Some(tools) = request_body.tools.as_deref() { ensure_request_mcp_client(mcp_orchestrator, tools).await } else { None @@ -72,7 +77,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response let mut response_json: Value; if let Some(mcp_servers) = mcp_servers { - let session_request_id = original_body + let session_request_id = request_body .request_id .clone() .unwrap_or_else(|| format!("req_{}", uuid::Uuid::now_v7())); @@ -83,7 +88,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response &session_request_id, forwarded_headers, ); - if let Some(tools) = original_body.tools.as_deref() { + if let Some(tools) = request_body.tools.as_deref() { session.configure_response_tools_approval(tools); } prepare_mcp_tools_as_functions(&mut payload, &session); @@ -95,8 +100,9 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response worker.api_key(), payload, ToolLoopExecutionContext { - original_body, + original_body: request_body, existing_mcp_list_tools_labels: &existing_mcp_list_tools_labels, + stateful_tool_bootstrap: &stateful_tool_bootstrap, session: &session, }, ) @@ -112,7 +118,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response } } - restore_original_tools(&mut response_json, original_body, Some(&session)); + restore_original_tools(&mut response_json, client_body, Some(&session)); } else { let mut request_builder = ctx.components.client().post(&url).json(&payload); let provider = ApiProvider::from_url(&url); @@ -156,11 +162,11 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response } }; - restore_original_tools(&mut response_json, original_body, None); + restore_original_tools(&mut response_json, client_body, None); } patch_response_with_request_metadata( &mut response_json, - original_body, + client_body, previous_response_id.as_deref(), ); @@ -169,12 +175,13 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response ctx.components.conversation_item_storage(), ctx.components.response_storage(), ) { + let persistence_body = build_persistence_request_body(request_body, client_body); if let Err(err) = persist_conversation_items( conv_storage.clone(), item_storage.clone(), resp_storage.clone(), &response_json, - original_body, + &persistence_body, ctx.storage_request_context.clone(), ) .await diff --git a/model_gateway/src/routers/openai/responses/route.rs b/model_gateway/src/routers/openai/responses/route.rs index cf131cae0..eb6b3bc54 100644 --- a/model_gateway/src/routers/openai/responses/route.rs +++ b/model_gateway/src/routers/openai/responses/route.rs @@ -13,11 +13,12 @@ use serde_json::to_value; use super::{ super::{ context::{ - ComponentRefs, PayloadState, RequestContext, ResponsesComponents, + ComponentRefs, PayloadState, RequestContext, RequestType, ResponsesComponents, ResponsesPayloadState, WorkerSelection, }, provider::ProviderRegistry, router::resolve_provider, + stateful_tools::{ensure_stateful_tool_bootstrap, StatefulToolBootstrapContext}, }, handle_non_streaming_response, handle_streaming_response, }; @@ -132,12 +133,34 @@ pub(in crate::routers::openai) async fn route_responses( super::history::inject_memory_context(&memory_config, &mut request_body); } + let client_request = Arc::new(body.clone()); + let mut ctx = RequestContext::for_responses( + Arc::clone(&client_request), + headers.cloned(), + Some(model_id.to_string()), + ComponentRefs::Responses(Arc::clone(deps.responses_components)), + ); + ctx.storage_request_context = smg_data_connector::current_request_context(); + ctx.tenant_request_meta = Some(tenant_meta.clone()); + let provider = resolve_provider(deps.provider_registry, worker.as_ref(), model); + + ctx.state.worker = Some(WorkerSelection { + worker: Arc::clone(&worker), + provider: Arc::clone(&provider), + }); + ctx.state.responses_payload = Some(ResponsesPayloadState { + client_request: Some(Arc::clone(&client_request)), + previous_response_id: loaded_history.previous_response_id, + existing_mcp_list_tools_labels: loaded_history.existing_mcp_list_tools_labels, + ..Default::default() + }); + request_body.store = Some(false); if let ResponseInput::Items(ref mut items) = request_body.input { items.retain(|item| !matches!(item, ResponseInputOutputItem::Reasoning { .. })); } - let mut payload = match to_value(&request_body) { + let mut preflight_payload = match to_value(&request_body) { Ok(v) => v, Err(e) => { Metrics::record_router_error( @@ -155,8 +178,7 @@ pub(in crate::routers::openai) async fn route_responses( } }; - let provider = resolve_provider(deps.provider_registry, worker.as_ref(), model); - if let Err(e) = provider.transform_request(&mut payload, Endpoint::Responses) { + if let Err(e) = provider.transform_request(&mut preflight_payload, Endpoint::Responses) { Metrics::record_router_error( metrics_labels::ROUTER_OPENAI, metrics_labels::BACKEND_EXTERNAL, @@ -168,28 +190,108 @@ pub(in crate::routers::openai) async fn route_responses( return error::bad_request("invalid_request", format!("Provider transform error: {e}")); } - let mut ctx = RequestContext::for_responses( - Arc::new(body.clone()), - headers.cloned(), - Some(model_id.to_string()), - ComponentRefs::Responses(Arc::clone(deps.responses_components)), - ); - ctx.storage_request_context = smg_data_connector::current_request_context(); - ctx.tenant_request_meta = Some(tenant_meta.clone()); + let stateful_tool_bootstrapper = match ctx.components.stateful_tool_bootstrapper() { + Some(bootstrapper) => Arc::clone(bootstrapper), + None => { + Metrics::record_router_error( + metrics_labels::ROUTER_OPENAI, + metrics_labels::BACKEND_EXTERNAL, + metrics_labels::CONNECTION_HTTP, + model, + metrics_labels::ENDPOINT_RESPONSES, + metrics_labels::ERROR_INTERNAL, + ); + return error::internal_error( + "internal_error", + "Stateful tool bootstrapper not configured", + ); + } + }; + let storage_request_context = ctx.storage_request_context.clone(); + let tenant_request_meta = ctx.tenant_request_meta.clone(); + let memory_execution_context = ctx.memory_execution_context.clone(); + let bootstrap_state = match ctx.state.responses_payload.as_mut() { + Some(responses_payload) => &mut responses_payload.stateful_tool_bootstrap, + None => { + Metrics::record_router_error( + metrics_labels::ROUTER_OPENAI, + metrics_labels::BACKEND_EXTERNAL, + metrics_labels::CONNECTION_HTTP, + model, + metrics_labels::ENDPOINT_RESPONSES, + metrics_labels::ERROR_INTERNAL, + ); + return error::internal_error( + "internal_error", + "Responses payload state not initialized", + ); + } + }; + if let Err(e) = ensure_stateful_tool_bootstrap( + &mut request_body, + bootstrap_state, + stateful_tool_bootstrapper.as_ref(), + StatefulToolBootstrapContext { + headers, + storage_request_context: storage_request_context.as_ref(), + memory_execution_context: &memory_execution_context, + tenant_request_meta: tenant_request_meta.as_ref(), + }, + ) + .await + { + Metrics::record_router_error( + metrics_labels::ROUTER_OPENAI, + metrics_labels::BACKEND_EXTERNAL, + metrics_labels::CONNECTION_HTTP, + model, + metrics_labels::ENDPOINT_RESPONSES, + metrics_labels::ERROR_INTERNAL, + ); + return error::internal_error( + "stateful_tool_bootstrap_failed", + format!("Failed to prepare stateful tool request state: {e}"), + ); + } + if let ResponseInput::Items(ref mut items) = request_body.input { + items.retain(|item| !matches!(item, ResponseInputOutputItem::Reasoning { .. })); + } + ctx.input.request_type = RequestType::Responses(Arc::new(request_body.clone())); - ctx.state.worker = Some(WorkerSelection { - worker: Arc::clone(&worker), - provider: Arc::clone(&provider), - }); + let mut payload = match to_value(&request_body) { + Ok(v) => v, + Err(e) => { + Metrics::record_router_error( + metrics_labels::ROUTER_OPENAI, + metrics_labels::BACKEND_EXTERNAL, + metrics_labels::CONNECTION_HTTP, + model, + metrics_labels::ENDPOINT_RESPONSES, + metrics_labels::ERROR_VALIDATION, + ); + return error::bad_request( + "invalid_request", + format!("Failed to serialize request: {e}"), + ); + } + }; + + if let Err(e) = provider.transform_request(&mut payload, Endpoint::Responses) { + Metrics::record_router_error( + metrics_labels::ROUTER_OPENAI, + metrics_labels::BACKEND_EXTERNAL, + metrics_labels::CONNECTION_HTTP, + model, + metrics_labels::ENDPOINT_RESPONSES, + metrics_labels::ERROR_VALIDATION, + ); + return error::bad_request("invalid_request", format!("Provider transform error: {e}")); + } ctx.state.payload = Some(PayloadState { json: payload, url: format!("{}/v1/responses", worker.url()), }); - ctx.state.responses_payload = Some(ResponsesPayloadState { - previous_response_id: loaded_history.previous_response_id, - existing_mcp_list_tools_labels: loaded_history.existing_mcp_list_tools_labels, - }); let response = if ctx.is_streaming() { handle_streaming_response(ctx).await diff --git a/model_gateway/src/routers/openai/responses/streaming.rs b/model_gateway/src/routers/openai/responses/streaming.rs index e22e36cf8..a544cdf85 100644 --- a/model_gateway/src/routers/openai/responses/streaming.rs +++ b/model_gateway/src/routers/openai/responses/streaming.rs @@ -36,8 +36,8 @@ use super::{ accumulator::StreamingResponseAccumulator, common::{extract_output_index, get_event_type, parse_sse_block, ChunkProcessor}, utils::{ - patch_response_with_request_metadata, response_tool_to_value, restore_original_tools, - rewrite_streaming_block, + build_persistence_request_body, patch_response_with_request_metadata, + response_tool_to_value, restore_original_tools, rewrite_streaming_block, }, }; const SSE_DONE: &str = "data: [DONE]\n\n"; @@ -559,8 +559,9 @@ pub(super) async fn handle_simple_streaming_passthrough( let (tx, rx) = mpsc::unbounded_channel::>(); - let should_store = req.original_body.store.unwrap_or(true); - let original_request = req.original_body; + let should_store = req.client_body.store.unwrap_or(true); + let client_request = req.client_body; + let request_body = req.request_body; let previous_response_id = req.previous_response_id; let storage = req.storage; @@ -582,7 +583,7 @@ pub(super) async fn handle_simple_streaming_passthrough( while let Some(raw_block) = chunk_processor.next_block() { let block_cow = match rewrite_streaming_block( &raw_block, - &original_request, + &client_request, previous_response_id.as_deref(), ) { Some(modified) => Cow::Owned(modified), @@ -626,17 +627,19 @@ pub(super) async fn handle_simple_streaming_passthrough( if let Some(mut response_json) = accumulator.into_final_response() { patch_response_with_request_metadata( &mut response_json, - &original_request, + &client_request, previous_response_id.as_deref(), ); // Always persist conversation items and response (even without conversation) + let persistence_body = + build_persistence_request_body(&request_body, &client_request); if let Err(err) = persist_conversation_items( storage.conversation.clone(), storage.conversation_item.clone(), storage.response.clone(), &response_json, - &original_request, + &persistence_body, storage.request_context.clone(), ) .await @@ -679,10 +682,12 @@ pub(super) fn handle_streaming_with_tool_interception( let payload = req.payload; let (tx, rx) = mpsc::unbounded_channel::>(); - let should_store = req.original_body.store.unwrap_or(true); - let original_request = req.original_body; + let should_store = req.client_body.store.unwrap_or(true); + let client_request = req.client_body; + let request_body = req.request_body; let previous_response_id = req.previous_response_id; let existing_mcp_list_tools_labels = req.existing_mcp_list_tools_labels; + let stateful_tool_bootstrap = req.stateful_tool_bootstrap; let url = req.url; let storage = req.storage; @@ -698,11 +703,16 @@ pub(super) fn handle_streaming_with_tool_interception( reason = "fire-and-forget MCP tool loop; gateway shutdown need not wait for individual tool loops" )] tokio::spawn(async move { - let mut state = ToolLoopState::new( - original_request.input.clone(), + let mut state = ToolLoopState::new_with_bootstrap( + request_body.input.clone(), existing_mcp_list_tools_labels, + stateful_tool_bootstrap, ); - let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize); + tracing::debug!( + prepared_stateful_tools = state.stateful_tool_bootstrap.prepared_tools.len(), + "Starting streaming tool loop" + ); + let max_tool_calls = request_body.max_tool_calls.map(|n| n as usize); // Create session inside spawned task (borrows from orchestrator_clone which lives in closure) let session_request_id = format!("resp_{}", uuid::Uuid::now_v7()); @@ -727,7 +737,7 @@ pub(super) fn handle_streaming_with_tool_interception( ); let streaming_ctx = StreamingEventContext { - original_request: &original_request, + original_request: &client_request, previous_response_id: previous_response_id.as_deref(), session: Some(&session), }; @@ -951,20 +961,22 @@ pub(super) fn handle_streaming_with_tool_interception( } inject_mcp_metadata_streaming(&mut response_json, &state, &session); - restore_original_tools(&mut response_json, &original_request, Some(&session)); + restore_original_tools(&mut response_json, &client_request, Some(&session)); patch_response_with_request_metadata( &mut response_json, - &original_request, + &client_request, previous_response_id.as_deref(), ); // Always persist conversation items and response (even without conversation) + let persistence_body = + build_persistence_request_body(&request_body, &client_request); if let Err(err) = persist_conversation_items( storage.conversation.clone(), storage.conversation_item.clone(), storage.response.clone(), &response_json, - &original_request, + &persistence_body, storage.request_context.clone(), ) .await @@ -988,7 +1000,7 @@ pub(super) fn handle_streaming_with_tool_interception( state.total_calls += pending_calls.len(); // Record tool loop iteration metric - Metrics::record_mcp_tool_iteration(&original_request.model); + Metrics::record_mcp_tool_iteration(&request_body.model); let effective_limit = match max_tool_calls { Some(user_max) => user_max.min(DEFAULT_MAX_ITERATIONS), @@ -1019,9 +1031,9 @@ pub(super) fn handle_streaming_with_tool_interception( &tx, &mut state, &mut sequence_number, - &original_request.model, - original_request.tools.as_deref().unwrap_or(&[]), - original_request.user.as_deref(), + &request_body.model, + request_body.tools.as_deref().unwrap_or(&[]), + request_body.user.as_deref(), ) .await { diff --git a/model_gateway/src/routers/openai/responses/utils.rs b/model_gateway/src/routers/openai/responses/utils.rs index 03a2b24a2..7974459dc 100644 --- a/model_gateway/src/routers/openai/responses/utils.rs +++ b/model_gateway/src/routers/openai/responses/utils.rs @@ -22,6 +22,35 @@ fn is_missing_or_empty(value: Option<&Value>) -> bool { } } +/// Build the request view used for SMG persistence. +/// +/// Provider execution uses `request_body` after history loading, bootstrap +/// injection, `store=false`, and replay sanitization. Persistence stores the +/// caller turn from `client_body` so replay-expanded history and request-scoped +/// bootstrap/memory context are not persisted into future turns. Other +/// execution-normalized fields still come from `request_body`, while +/// caller-owned metadata such as `conversation`, `previous_response_id`, and +/// `store` are restored from `client_body`. +pub(super) fn build_persistence_request_body( + request_body: &ResponsesRequest, + client_body: &ResponsesRequest, +) -> ResponsesRequest { + let mut persistence_body = request_body.clone(); + // Persist only the client turn. The normalized request body may contain + // replayed history or request-scoped bootstrap context that must not be + // stored and replayed again on the next request. + persistence_body.input = client_body.input.clone(); + persistence_body + .conversation + .clone_from(&client_body.conversation); + persistence_body + .previous_response_id + .clone_from(&client_body.previous_response_id); + persistence_body.store = client_body.store; + persistence_body.user.clone_from(&client_body.user); + persistence_body +} + /// Insert a string value into a JSON object if the condition is met fn insert_if(obj: &mut Map, key: &str, value: &str, condition: F) where @@ -425,7 +454,8 @@ mod tests { }; use super::{ - patch_response_with_request_metadata, restore_original_tools, rewrite_streaming_block, + build_persistence_request_body, patch_response_with_request_metadata, + restore_original_tools, rewrite_streaming_block, }; fn test_tool(name: &str) -> Tool { @@ -449,6 +479,61 @@ mod tests { } } + #[test] + fn persistence_body_keeps_previous_response_id_without_replay_expanded_input() { + let request_body = ResponsesRequest { + model: "gpt-5.4".to_string(), + input: ResponseInput::Items(vec![]), + previous_response_id: None, + store: Some(false), + ..Default::default() + }; + let client_body = ResponsesRequest { + model: "gpt-5.4".to_string(), + input: ResponseInput::Text("current turn".to_string()), + previous_response_id: Some("resp_prev".to_string()), + store: Some(true), + ..Default::default() + }; + + let persistence_body = build_persistence_request_body(&request_body, &client_body); + + assert_eq!( + persistence_body.previous_response_id.as_deref(), + Some("resp_prev") + ); + assert!(matches!( + persistence_body.input, + ResponseInput::Text(ref text) if text == "current turn" + )); + assert_eq!(persistence_body.store, Some(true)); + } + + #[test] + fn persistence_body_uses_client_input_for_root_request() { + let request_body = ResponsesRequest { + model: "gpt-5.4".to_string(), + input: ResponseInput::Items(vec![]), + store: Some(false), + ..Default::default() + }; + let client_body = ResponsesRequest { + model: "gpt-5.4".to_string(), + input: ResponseInput::Text("root turn".to_string()), + store: Some(true), + ..Default::default() + }; + + let persistence_body = build_persistence_request_body(&request_body, &client_body); + + assert!(matches!( + persistence_body.input, + ResponseInput::Text(ref text) if text == "root turn" + )); + assert_eq!(persistence_body.previous_response_id, None); + assert_eq!(persistence_body.store, Some(true)); + } + #[tokio::test] async fn restore_original_tools_strips_injected_internal_tool_when_request_had_no_tools() { let original_body = ResponsesRequest { diff --git a/model_gateway/src/routers/openai/router.rs b/model_gateway/src/routers/openai/router.rs index 0bda10e99..39f5525db 100644 --- a/model_gateway/src/routers/openai/router.rs +++ b/model_gateway/src/routers/openai/router.rs @@ -19,6 +19,7 @@ use super::{ health, provider::ProviderRegistry, responses::route::{self as responses_route, ResponsesRouterContext}, + stateful_tools::{NoOpStatefulToolBootstrapper, SharedStatefulToolBootstrapper}, }; use crate::{ app_context::AppContext, @@ -76,11 +77,18 @@ impl std::fmt::Debug for OpenAIRouter { } impl OpenAIRouter { + pub async fn new(ctx: &Arc) -> Result { + Self::new_with_stateful_tool_bootstrapper(ctx, Arc::new(NoOpStatefulToolBootstrapper)).await + } + #[expect( clippy::unused_async, reason = "async for API consistency with other router constructors" )] - pub async fn new(ctx: &Arc) -> Result { + pub async fn new_with_stateful_tool_bootstrapper( + ctx: &Arc, + stateful_tool_bootstrapper: SharedStatefulToolBootstrapper, + ) -> Result { let worker_registry = ctx.worker_registry.clone(); let mcp_orchestrator = ctx .mcp_orchestrator @@ -96,6 +104,7 @@ impl OpenAIRouter { let responses_components = Arc::new(ResponsesComponents { shared: Arc::clone(&shared_components), mcp_orchestrator: mcp_orchestrator.clone(), + stateful_tool_bootstrapper, response_storage: ctx.response_storage.clone(), conversation_storage: ctx.conversation_storage.clone(), conversation_item_storage: ctx.conversation_item_storage.clone(), diff --git a/model_gateway/src/routers/openai/stateful_tools.rs b/model_gateway/src/routers/openai/stateful_tools.rs new file mode 100644 index 000000000..937056f05 --- /dev/null +++ b/model_gateway/src/routers/openai/stateful_tools.rs @@ -0,0 +1,372 @@ +//! Generic request-scoped bootstrap support for stateful Responses tools. +//! +//! The upstream runtime only knows that some tools require prepared state +//! before the first model call. Provider-specific lifecycle logic stays behind +//! the bootstrapper implementation, and the prepared state itself is stored as +//! opaque JSON so upstream does not learn OCI container/session semantics. + +use std::{collections::BTreeSet, sync::Arc}; + +use async_trait::async_trait; +use axum::http::HeaderMap; +use openai_protocol::responses::{ + generate_id, ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseTool, + ResponsesRequest, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use smg_data_connector::RequestContext as StorageRequestContext; + +use crate::{memory::MemoryExecutionContext, middleware::TenantRequestMeta}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum StatefulToolKind { + CodeInterpreter, + Shell, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PreparedToolState { + pub kind: StatefulToolKind, + #[serde(default, skip_serializing_if = "Value::is_null")] + pub value: Value, +} + +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct StatefulToolBootstrapState { + pub executed: bool, + pub prepared_tools: Vec, +} + +impl StatefulToolBootstrapState { + pub fn prepared_tool(&self, kind: StatefulToolKind) -> Option<&Value> { + self.prepared_tools + .iter() + .find(|tool| tool.kind == kind) + .map(|tool| &tool.value) + } + + pub fn upsert_prepared_tool(&mut self, kind: StatefulToolKind, value: Value) { + if let Some(existing) = self + .prepared_tools + .iter_mut() + .find(|tool| tool.kind == kind) + { + existing.value = value; + return; + } + self.prepared_tools.push(PreparedToolState { kind, value }); + } +} + +#[derive(Debug, Default)] +pub struct StatefulToolBootstrapResult { + pub prepared_tools: Vec, + pub injected_input_items: Vec, +} + +pub struct StatefulToolBootstrapContext<'a> { + pub headers: Option<&'a HeaderMap>, + pub storage_request_context: Option<&'a StorageRequestContext>, + pub memory_execution_context: &'a MemoryExecutionContext, + pub tenant_request_meta: Option<&'a TenantRequestMeta>, +} + +#[async_trait] +pub trait StatefulToolBootstrapper: Send + Sync { + async fn bootstrap( + &self, + request: &ResponsesRequest, + context: StatefulToolBootstrapContext<'_>, + ) -> Result; +} + +pub struct NoOpStatefulToolBootstrapper; + +#[async_trait] +impl StatefulToolBootstrapper for NoOpStatefulToolBootstrapper { + async fn bootstrap( + &self, + _request: &ResponsesRequest, + _context: StatefulToolBootstrapContext<'_>, + ) -> Result { + Ok(StatefulToolBootstrapResult::default()) + } +} + +pub fn declared_stateful_tool_kinds(tools: Option<&[ResponseTool]>) -> Vec { + let mut kinds = BTreeSet::new(); + + for tool in tools.unwrap_or(&[]) { + match tool { + ResponseTool::CodeInterpreter(_) => { + kinds.insert(StatefulToolKind::CodeInterpreter); + } + ResponseTool::Shell(_) => { + kinds.insert(StatefulToolKind::Shell); + } + _ => {} + } + } + + kinds.into_iter().collect() +} + +pub fn request_has_stateful_tools(request: &ResponsesRequest) -> bool { + !declared_stateful_tool_kinds(request.tools.as_deref()).is_empty() +} + +pub async fn ensure_stateful_tool_bootstrap( + request: &mut ResponsesRequest, + bootstrap_state: &mut StatefulToolBootstrapState, + bootstrapper: &dyn StatefulToolBootstrapper, + context: StatefulToolBootstrapContext<'_>, +) -> Result<(), String> { + if bootstrap_state.executed || !request_has_stateful_tools(request) { + return Ok(()); + } + + // INVARIANT: bootstrap failures short-circuit the route immediately, so + // `executed` is only flipped after a successful injection/preparation. + let result = bootstrapper.bootstrap(request, context).await?; + prepend_injected_items(&mut request.input, result.injected_input_items); + + bootstrap_state.executed = true; + for tool in result.prepared_tools { + bootstrap_state.upsert_prepared_tool(tool.kind, tool.value); + } + + Ok(()) +} + +fn prepend_injected_items( + input: &mut ResponseInput, + mut injected_items: Vec, +) { + if injected_items.is_empty() { + return; + } + + match input { + ResponseInput::Text(text) => { + injected_items.push(ResponseInputOutputItem::Message { + id: generate_id("msg"), + role: "user".to_string(), + content: vec![ResponseContentPart::InputText { text: text.clone() }], + status: None, + phase: None, + }); + *input = ResponseInput::Items(injected_items); + } + ResponseInput::Items(existing_items) => { + injected_items.append(existing_items); + *existing_items = injected_items; + } + } +} + +pub type SharedStatefulToolBootstrapper = Arc; + +#[cfg(test)] +mod tests { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + use openai_protocol::responses::{ + CodeInterpreterTool, ResponseToolEnvironment, ShellTool, WebSearchPreviewTool, + }; + use serde_json::json; + + use super::*; + + struct CountingBootstrapper { + calls: Arc, + result: StatefulToolBootstrapResult, + } + + #[async_trait] + impl StatefulToolBootstrapper for CountingBootstrapper { + async fn bootstrap( + &self, + _request: &ResponsesRequest, + _context: StatefulToolBootstrapContext<'_>, + ) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok(StatefulToolBootstrapResult { + prepared_tools: self.result.prepared_tools.clone(), + injected_input_items: self.result.injected_input_items.clone(), + }) + } + } + + struct FailingBootstrapper; + + #[async_trait] + impl StatefulToolBootstrapper for FailingBootstrapper { + async fn bootstrap( + &self, + _request: &ResponsesRequest, + _context: StatefulToolBootstrapContext<'_>, + ) -> Result { + Err("boom".to_string()) + } + } + + fn bootstrap_context<'a>( + memory_execution_context: &'a MemoryExecutionContext, + ) -> StatefulToolBootstrapContext<'a> { + StatefulToolBootstrapContext { + headers: None, + storage_request_context: None, + memory_execution_context, + tenant_request_meta: None, + } + } + + #[test] + fn declared_stateful_tool_kinds_deduplicates_supported_tools() { + let tools = vec![ + ResponseTool::WebSearchPreview(WebSearchPreviewTool::default()), + ResponseTool::Shell(ShellTool::default()), + ResponseTool::CodeInterpreter(CodeInterpreterTool { + container: None, + environment: Some(ResponseToolEnvironment::default()), + }), + ResponseTool::Shell(ShellTool::default()), + ]; + + assert_eq!( + declared_stateful_tool_kinds(Some(&tools)), + vec![StatefulToolKind::CodeInterpreter, StatefulToolKind::Shell] + ); + } + + #[tokio::test] + async fn ensure_stateful_tool_bootstrap_skips_requests_without_stateful_tools() { + let calls = Arc::new(AtomicUsize::new(0)); + let bootstrapper = CountingBootstrapper { + calls: Arc::clone(&calls), + result: StatefulToolBootstrapResult::default(), + }; + let memory_execution_context = MemoryExecutionContext::default(); + let mut request = ResponsesRequest { + input: ResponseInput::Text("hello".to_string()), + tools: Some(vec![ResponseTool::WebSearchPreview( + WebSearchPreviewTool::default(), + )]), + ..Default::default() + }; + let mut bootstrap_state = StatefulToolBootstrapState::default(); + + ensure_stateful_tool_bootstrap( + &mut request, + &mut bootstrap_state, + &bootstrapper, + bootstrap_context(&memory_execution_context), + ) + .await + .expect("bootstrap should succeed"); + + assert_eq!(calls.load(Ordering::SeqCst), 0); + assert!(!bootstrap_state.executed); + } + + #[tokio::test] + async fn ensure_stateful_tool_bootstrap_injects_context_and_runs_once() { + let calls = Arc::new(AtomicUsize::new(0)); + let bootstrapper = CountingBootstrapper { + calls: Arc::clone(&calls), + result: StatefulToolBootstrapResult { + prepared_tools: vec![PreparedToolState { + kind: StatefulToolKind::Shell, + value: json!({"session_id": "sess_123"}), + }], + injected_input_items: vec![ResponseInputOutputItem::Message { + id: "msg_bootstrap".to_string(), + role: "developer".to_string(), + content: vec![ResponseContentPart::InputText { + text: "Resolved shell session is available.".to_string(), + }], + status: Some("completed".to_string()), + phase: None, + }], + }, + }; + let memory_execution_context = MemoryExecutionContext::default(); + let mut request = ResponsesRequest { + input: ResponseInput::Text("hello".to_string()), + tools: Some(vec![ResponseTool::Shell(ShellTool::default())]), + ..Default::default() + }; + let mut bootstrap_state = StatefulToolBootstrapState::default(); + + ensure_stateful_tool_bootstrap( + &mut request, + &mut bootstrap_state, + &bootstrapper, + bootstrap_context(&memory_execution_context), + ) + .await + .expect("first bootstrap should succeed"); + ensure_stateful_tool_bootstrap( + &mut request, + &mut bootstrap_state, + &bootstrapper, + bootstrap_context(&memory_execution_context), + ) + .await + .expect("second bootstrap should be a no-op"); + + assert_eq!(calls.load(Ordering::SeqCst), 1); + assert!(bootstrap_state.executed); + assert_eq!( + bootstrap_state.prepared_tool(StatefulToolKind::Shell), + Some(&json!({"session_id": "sess_123"})) + ); + + let ResponseInput::Items(items) = &request.input else { + panic!("bootstrap should normalize text input to items"); + }; + assert_eq!(items.len(), 2); + assert!(matches!( + &items[0], + ResponseInputOutputItem::Message { role, .. } if role == "developer" + )); + assert!(matches!( + &items[1], + ResponseInputOutputItem::Message { role, .. } if role == "user" + )); + } + + #[tokio::test] + async fn ensure_stateful_tool_bootstrap_error_is_non_mutating() { + let memory_execution_context = MemoryExecutionContext::default(); + let mut request = ResponsesRequest { + input: ResponseInput::Text("hello".to_string()), + tools: Some(vec![ResponseTool::Shell(ShellTool::default())]), + ..Default::default() + }; + let original_input = serde_json::to_value(&request.input).expect("serialize input"); + let mut bootstrap_state = StatefulToolBootstrapState::default(); + + let err = ensure_stateful_tool_bootstrap( + &mut request, + &mut bootstrap_state, + &FailingBootstrapper, + bootstrap_context(&memory_execution_context), + ) + .await + .expect_err("bootstrap should fail"); + + assert_eq!(err, "boom"); + assert!(!bootstrap_state.executed); + assert_eq!( + serde_json::to_value(&request.input).expect("serialize input"), + original_input + ); + assert!(bootstrap_state.prepared_tools.is_empty()); + } +} diff --git a/model_gateway/tests/api/responses_api_test.rs b/model_gateway/tests/api/responses_api_test.rs index 345fe2a5a..4e7cbef44 100644 --- a/model_gateway/tests/api/responses_api_test.rs +++ b/model_gateway/tests/api/responses_api_test.rs @@ -1,23 +1,43 @@ // Integration test for Responses API +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use async_trait::async_trait; use axum::http::{HeaderMap, HeaderValue, StatusCode}; use openai_protocol::{ common::{GenerationRequest, UsageInfo}, responses::{ CodeInterpreterTool, McpTool, ReasoningEffort, RequireApproval, RequireApprovalMode, - ResponseInput, ResponseReasoningParam, ResponseTool, ResponsesRequest, ResponsesToolChoice, - ServiceTier, ToolChoiceOptions, Truncation, WebSearchPreviewTool, + ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseReasoningParam, + ResponseTool, ResponsesRequest, ResponsesToolChoice, ServiceTier, ShellTool, + ToolChoiceOptions, Truncation, WebSearchPreviewTool, }, }; use smg::{ config::RouterConfig, - routers::{conversations, RouterFactory}, + routers::{ + conversations, + openai::{ + stateful_tools::{ + PreparedToolState, StatefulToolBootstrapContext, StatefulToolBootstrapResult, + StatefulToolBootstrapper, StatefulToolKind, + }, + OpenAIRouter, + }, + RouterFactory, RouterTrait, + }, tenant::{RouteRequestMeta, TenantKey}, }; use crate::common::{ mock_mcp_server::{MockFailingMCPServer, MockMCPServer}, - mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}, + mock_worker::{ + take_recorded_responses_requests_for_port, HealthStatus, MockWorker, MockWorkerConfig, + WorkerType, + }, }; const TEST_INTERNAL_MCP_SERVER_LABEL: &str = "internal-mock"; @@ -27,6 +47,518 @@ fn test_tenant_meta() -> smg::middleware::TenantRequestMeta { RouteRequestMeta::new(TenantKey::from("test-tenant")) } +struct TestBootstrapper { + calls: Arc, + result: StatefulToolBootstrapResult, + error: Option, +} + +#[async_trait] +impl StatefulToolBootstrapper for TestBootstrapper { + async fn bootstrap( + &self, + _request: &ResponsesRequest, + _context: StatefulToolBootstrapContext<'_>, + ) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + if let Some(error) = &self.error { + return Err(error.clone()); + } + + Ok(StatefulToolBootstrapResult { + prepared_tools: self.result.prepared_tools.clone(), + injected_input_items: self.result.injected_input_items.clone(), + }) + } +} + +fn bootstrap_injected_message(text: &str) -> ResponseInputOutputItem { + ResponseInputOutputItem::Message { + id: openai_protocol::responses::generate_id("msg"), + role: "system".to_string(), + content: vec![ResponseContentPart::InputText { + text: text.to_string(), + }], + status: None, + phase: None, + } +} + +fn stateful_tool_request(stream: bool) -> ResponsesRequest { + ResponsesRequest { + background: Some(false), + include: None, + input: ResponseInput::Text("Run the stateful tool".to_string()), + instructions: Some("Use the tools if needed.".to_string()), + max_output_tokens: Some(64), + max_tool_calls: None, + metadata: None, + model: "mock-model".to_string(), + parallel_tool_calls: Some(true), + previous_response_id: None, + reasoning: None, + service_tier: Some(ServiceTier::Auto), + store: Some(true), + stream: Some(stream), + temperature: Some(0.2), + tool_choice: Some(ResponsesToolChoice::default()), + tools: Some(vec![ + ResponseTool::Shell(ShellTool::default()), + ResponseTool::CodeInterpreter(CodeInterpreterTool::default()), + ]), + top_logprobs: Some(0), + top_p: None, + truncation: Some(Truncation::Disabled), + text: None, + user: None, + request_id: Some("resp_stateful_tool_bootstrap".to_string()), + priority: 0, + frequency_penalty: Some(0.0), + presence_penalty: Some(0.0), + stop: None, + prompt: None, + prompt_cache_key: None, + prompt_cache_retention: None, + safety_identifier: None, + stream_options: None, + context_management: None, + top_k: -1, + min_p: 0.0, + repetition_penalty: 1.0, + conversation: None, + } +} + +fn non_stateful_tool_request() -> ResponsesRequest { + ResponsesRequest { + tools: Some(vec![ResponseTool::WebSearchPreview( + WebSearchPreviewTool::default(), + )]), + ..stateful_tool_request(false) + } +} + +fn stateful_tool_request_with_mcp(stream: bool, mcp_url: String) -> ResponsesRequest { + ResponsesRequest { + tools: Some(vec![ + ResponseTool::Shell(ShellTool::default()), + ResponseTool::Mcp(McpTool { + server_url: Some(mcp_url), + authorization: None, + headers: None, + server_label: "mock".to_string(), + server_description: None, + require_approval: Some(RequireApproval::Mode(RequireApprovalMode::Never)), + allowed_tools: None, + connector_id: None, + defer_loading: None, + }), + ]), + ..stateful_tool_request(stream) + } +} + +fn bootstrap_payload_texts(payload: &serde_json::Value) -> Vec { + payload + .get("input") + .and_then(|value| value.as_array()) + .into_iter() + .flatten() + .filter_map(|item| item.get("content").and_then(|value| value.as_array())) + .flat_map(|parts| parts.iter()) + .filter_map(|part| part.get("text").and_then(|value| value.as_str())) + .map(ToString::to_string) + .collect() +} + +fn assert_resume_payload_preserves_bootstrap_text( + payload: &serde_json::Value, + bootstrap_text: &str, + message: &str, +) { + let texts = bootstrap_payload_texts(payload); + assert_eq!( + texts.first().map(String::as_str), + Some(bootstrap_text), + "{message}: bootstrap text should remain first; texts={texts:?}" + ); + assert!( + texts.iter().any(|text| text == "Run the stateful tool"), + "{message}: current user text should be present; texts={texts:?}" + ); +} + +fn standard_openai_router_config(worker_url: String) -> RouterConfig { + RouterConfig::builder() + .openai_mode(vec![worker_url]) + .random_policy() + .host("127.0.0.1") + .port(0) + .max_payload_size(8 * 1024 * 1024) + .request_timeout_secs(60) + .worker_startup_timeout_secs(5) + .worker_startup_check_interval_secs(1) + .log_level("warn") + .max_concurrent_requests(32) + .queue_timeout_secs(5) + .build_unchecked() +} + +#[tokio::test] +async fn test_non_streaming_stateful_tool_bootstrap_injects_payload_before_worker_call() { + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + let worker_port = worker.port().await; + + let ctx = crate::common::create_test_context(standard_openai_router_config(worker_url)).await; + let calls = Arc::new(AtomicUsize::new(0)); + let bootstrapper = Arc::new(TestBootstrapper { + calls: Arc::clone(&calls), + result: StatefulToolBootstrapResult { + prepared_tools: vec![ + PreparedToolState { + kind: StatefulToolKind::Shell, + value: serde_json::json!({ "container_id": "ctr-shell" }), + }, + PreparedToolState { + kind: StatefulToolKind::CodeInterpreter, + value: serde_json::json!({ "container_id": "ctr-python" }), + }, + ], + injected_input_items: vec![bootstrap_injected_message("bootstrap context")], + }, + error: None, + }); + let router = OpenAIRouter::new_with_stateful_tool_bootstrapper(&ctx, bootstrapper) + .await + .expect("router"); + + let req = stateful_tool_request(false); + let tenant_meta = test_tenant_meta(); + let resp = router + .route_responses(None, &tenant_meta, &req, req.model.as_str()) + .await; + + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(calls.load(Ordering::SeqCst), 1); + + let requests = take_recorded_responses_requests_for_port(worker_port); + assert_eq!(requests.len(), 1, "expected one downstream worker call"); + + let payload = &requests[0]; + assert_eq!(payload["store"], false); + assert_eq!( + bootstrap_payload_texts(payload), + vec![ + "bootstrap context".to_string(), + "Run the stateful tool".to_string(), + ] + ); + + worker.stop().await; +} + +#[tokio::test] +async fn test_streaming_stateful_tool_bootstrap_injects_payload_before_worker_call() { + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + let worker_port = worker.port().await; + + let ctx = crate::common::create_test_context(standard_openai_router_config(worker_url)).await; + let calls = Arc::new(AtomicUsize::new(0)); + let bootstrapper = Arc::new(TestBootstrapper { + calls: Arc::clone(&calls), + result: StatefulToolBootstrapResult { + prepared_tools: vec![PreparedToolState { + kind: StatefulToolKind::Shell, + value: serde_json::json!({ "container_id": "ctr-shell" }), + }], + injected_input_items: vec![bootstrap_injected_message("stream bootstrap context")], + }, + error: None, + }); + let router = OpenAIRouter::new_with_stateful_tool_bootstrapper(&ctx, bootstrapper) + .await + .expect("router"); + + let req = stateful_tool_request(true); + let tenant_meta = test_tenant_meta(); + let resp = router + .route_responses(None, &tenant_meta, &req, req.model.as_str()) + .await; + + assert_eq!(resp.status(), StatusCode::OK); + let body_bytes = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .expect("read streaming body"); + let body_text = String::from_utf8_lossy(&body_bytes); + assert!(body_text.contains("[DONE]")); + assert_eq!(calls.load(Ordering::SeqCst), 1); + + let requests = take_recorded_responses_requests_for_port(worker_port); + assert_eq!(requests.len(), 1, "expected one downstream worker call"); + assert_eq!( + bootstrap_payload_texts(&requests[0]), + vec![ + "stream bootstrap context".to_string(), + "Run the stateful tool".to_string(), + ] + ); + + worker.stop().await; +} + +#[tokio::test] +async fn test_router_skips_stateful_tool_bootstrap_when_request_has_no_stateful_tools() { + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + let worker_port = worker.port().await; + + let ctx = crate::common::create_test_context(standard_openai_router_config(worker_url)).await; + let calls = Arc::new(AtomicUsize::new(0)); + let bootstrapper = Arc::new(TestBootstrapper { + calls: Arc::clone(&calls), + result: StatefulToolBootstrapResult { + prepared_tools: Vec::new(), + injected_input_items: vec![bootstrap_injected_message("should not be injected")], + }, + error: None, + }); + let router = OpenAIRouter::new_with_stateful_tool_bootstrapper(&ctx, bootstrapper) + .await + .expect("router"); + + let req = non_stateful_tool_request(); + let tenant_meta = test_tenant_meta(); + let resp = router + .route_responses(None, &tenant_meta, &req, req.model.as_str()) + .await; + + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(calls.load(Ordering::SeqCst), 0); + + let requests = take_recorded_responses_requests_for_port(worker_port); + assert_eq!(requests.len(), 1, "expected one downstream worker call"); + assert_eq!( + requests[0]["input"], + serde_json::json!("Run the stateful tool") + ); + + worker.stop().await; +} + +#[tokio::test] +async fn test_stateful_tool_bootstrap_failure_short_circuits_before_worker_call() { + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + let worker_port = worker.port().await; + + let ctx = crate::common::create_test_context(standard_openai_router_config(worker_url)).await; + let calls = Arc::new(AtomicUsize::new(0)); + let bootstrapper = Arc::new(TestBootstrapper { + calls: Arc::clone(&calls), + result: StatefulToolBootstrapResult::default(), + error: Some("bootstrap failed".to_string()), + }); + let router = OpenAIRouter::new_with_stateful_tool_bootstrapper(&ctx, bootstrapper) + .await + .expect("router"); + + let req = stateful_tool_request(false); + let tenant_meta = test_tenant_meta(); + let resp = router + .route_responses(None, &tenant_meta, &req, req.model.as_str()) + .await; + + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(calls.load(Ordering::SeqCst), 1); + assert!( + take_recorded_responses_requests_for_port(worker_port).is_empty(), + "worker should not receive a request when bootstrap fails" + ); + + let body_bytes = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .expect("read error body"); + let body_json: serde_json::Value = + serde_json::from_slice(&body_bytes).expect("parse error body"); + assert_eq!(body_json["error"]["code"], "stateful_tool_bootstrap_failed"); + + worker.stop().await; +} + +#[tokio::test] +async fn test_non_streaming_tool_loop_preserves_bootstrapped_input_on_resume() { + let mut mcp = MockMCPServer::start().await.expect("start mcp"); + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + let worker_port = worker.port().await; + + let ctx = crate::common::create_test_context(standard_openai_router_config(worker_url)).await; + let calls = Arc::new(AtomicUsize::new(0)); + let bootstrapper = Arc::new(TestBootstrapper { + calls: Arc::clone(&calls), + result: StatefulToolBootstrapResult { + prepared_tools: vec![PreparedToolState { + kind: StatefulToolKind::Shell, + value: serde_json::json!({ "container_id": "ctr-shell" }), + }], + injected_input_items: vec![bootstrap_injected_message("bootstrap context")], + }, + error: None, + }); + let router = OpenAIRouter::new_with_stateful_tool_bootstrapper(&ctx, bootstrapper) + .await + .expect("router"); + + let req = stateful_tool_request_with_mcp(false, mcp.url()); + let tenant_meta = test_tenant_meta(); + let resp = router + .route_responses(None, &tenant_meta, &req, req.model.as_str()) + .await; + + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(calls.load(Ordering::SeqCst), 1); + + let requests = take_recorded_responses_requests_for_port(worker_port); + assert_eq!( + requests.len(), + 2, + "expected initial and resumed downstream worker calls" + ); + + assert_eq!( + bootstrap_payload_texts(&requests[0]), + vec![ + "bootstrap context".to_string(), + "Run the stateful tool".to_string(), + ] + ); + assert_resume_payload_preserves_bootstrap_text( + &requests[1], + "bootstrap context", + "resume payload should preserve bootstrapped input items", + ); + assert!( + requests[1] + .get("input") + .and_then(|value| value.as_array()) + .is_some_and(|items| items.iter().any(|item| { + item.get("type").and_then(|value| value.as_str()) == Some("function_call_output") + })), + "resume payload should include function_call_output conversation history" + ); + + worker.stop().await; + mcp.stop().await; +} + +#[tokio::test] +async fn test_streaming_tool_loop_preserves_bootstrapped_input_on_resume() { + let mut mcp = MockMCPServer::start().await.expect("start mcp"); + let mut worker = MockWorker::new(MockWorkerConfig { + port: 0, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let worker_url = worker.start().await.expect("start worker"); + let worker_port = worker.port().await; + + let ctx = crate::common::create_test_context(standard_openai_router_config(worker_url)).await; + let calls = Arc::new(AtomicUsize::new(0)); + let bootstrapper = Arc::new(TestBootstrapper { + calls: Arc::clone(&calls), + result: StatefulToolBootstrapResult { + prepared_tools: vec![PreparedToolState { + kind: StatefulToolKind::Shell, + value: serde_json::json!({ "container_id": "ctr-shell" }), + }], + injected_input_items: vec![bootstrap_injected_message("stream bootstrap context")], + }, + error: None, + }); + let router = OpenAIRouter::new_with_stateful_tool_bootstrapper(&ctx, bootstrapper) + .await + .expect("router"); + + let req = stateful_tool_request_with_mcp(true, mcp.url()); + let tenant_meta = test_tenant_meta(); + let resp = router + .route_responses(None, &tenant_meta, &req, req.model.as_str()) + .await; + + assert_eq!(resp.status(), StatusCode::OK); + let body_bytes = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .expect("read streaming body"); + let body_text = String::from_utf8_lossy(&body_bytes); + assert!(body_text.contains("[DONE]")); + assert_eq!(calls.load(Ordering::SeqCst), 1); + + let requests = take_recorded_responses_requests_for_port(worker_port); + assert_eq!( + requests.len(), + 2, + "expected initial and resumed downstream worker calls" + ); + assert_eq!( + bootstrap_payload_texts(&requests[0]), + vec![ + "stream bootstrap context".to_string(), + "Run the stateful tool".to_string(), + ] + ); + assert_resume_payload_preserves_bootstrap_text( + &requests[1], + "stream bootstrap context", + "streaming resume payload should preserve bootstrapped input items", + ); + assert!( + requests[1] + .get("input") + .and_then(|value| value.as_array()) + .is_some_and(|items| items.iter().any(|item| { + item.get("type").and_then(|value| value.as_str()) == Some("function_call_output") + })), + "streaming resume payload should include function_call_output conversation history" + ); + + worker.stop().await; + mcp.stop().await; +} + #[tokio::test] async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { // Start mock MCP server @@ -1591,7 +2123,7 @@ async fn test_max_tool_calls_limit() { async fn setup_streaming_mcp_test() -> ( MockMCPServer, MockWorker, - Box, + Box, tempfile::TempDir, ) { let mcp = MockMCPServer::start().await.expect("start mcp"); diff --git a/model_gateway/tests/common/mock_worker.rs b/model_gateway/tests/common/mock_worker.rs index 44456ffc8..3963748b0 100755 --- a/model_gateway/tests/common/mock_worker.rs +++ b/model_gateway/tests/common/mock_worker.rs @@ -83,6 +83,7 @@ impl MockWorker { } else { port }; + clear_recorded_responses_requests_for_port(port); let app = Router::new() .route("/health", get(health_handler)) @@ -138,6 +139,9 @@ impl MockWorker { /// Stop the mock worker server pub async fn stop(&mut self) { + let port = self.config.read().await.port; + clear_recorded_responses_requests_for_port(port); + if let Some(shutdown_tx) = self.shutdown_tx.take() { let _ = shutdown_tx.send(()); } @@ -147,10 +151,18 @@ impl MockWorker { let _ = tokio::time::timeout(tokio::time::Duration::from_secs(5), handle).await; } } + + pub async fn port(&self) -> u16 { + self.config.read().await.port + } } impl Drop for MockWorker { fn drop(&mut self) { + if let Ok(config) = self.config.try_read() { + clear_recorded_responses_requests_for_port(config.port); + } + // Clean shutdown when dropped if let Some(shutdown_tx) = self.shutdown_tx.take() { let _ = shutdown_tx.send(()); @@ -594,6 +606,7 @@ async fn responses_handler( Json(payload): Json, ) -> Response { let config = config.read().await; + record_responses_request_for_port(config.port, payload.clone()); if should_fail(&config) { return ( @@ -1232,11 +1245,17 @@ async fn responses_cancel_handler( // --- Simple in-memory response store per worker port (for tests) --- static RESP_STORE: OnceLock>>> = OnceLock::new(); +static RESPONSES_REQUEST_STORE: OnceLock>>> = + OnceLock::new(); fn get_store() -> &'static Mutex>> { RESP_STORE.get_or_init(|| Mutex::new(HashMap::new())) } +fn get_responses_request_store() -> &'static Mutex>> { + RESPONSES_REQUEST_STORE.get_or_init(|| Mutex::new(HashMap::new())) +} + #[expect( clippy::unwrap_used, reason = "test helper - panicking on failure is intentional" @@ -1257,6 +1276,33 @@ fn response_exists_for_port(port: u16, response_id: &str) -> bool { .unwrap_or(false) } +#[expect( + clippy::unwrap_used, + reason = "test helper - panicking on failure is intentional" +)] +fn record_responses_request_for_port(port: u16, payload: serde_json::Value) { + let mut map = get_responses_request_store().lock().unwrap(); + map.entry(port).or_default().push(payload); +} + +#[expect( + clippy::unwrap_used, + reason = "test helper - panicking on failure is intentional" +)] +pub fn take_recorded_responses_requests_for_port(port: u16) -> Vec { + let mut map = get_responses_request_store().lock().unwrap(); + map.remove(&port).unwrap_or_default() +} + +#[expect( + clippy::unwrap_used, + reason = "test helper - panicking on failure is intentional" +)] +fn clear_recorded_responses_requests_for_port(port: u16) { + let mut map = get_responses_request_store().lock().unwrap(); + map.remove(&port); +} + // Minimal rerank handler returning mock results; router shapes final response #[expect( clippy::unwrap_used, diff --git a/model_gateway/tests/routing/test_openai_routing.rs b/model_gateway/tests/routing/test_openai_routing.rs index 0f66d3c5a..319b496c6 100644 --- a/model_gateway/tests/routing/test_openai_routing.rs +++ b/model_gateway/tests/routing/test_openai_routing.rs @@ -539,7 +539,10 @@ async fn test_openai_router_responses_streaming_with_mock() { sleep(Duration::from_millis(10)).await; }; - // Input is now stored as a JSON array of items + // Persistence stores only the caller turn for previous_response_id requests. + // The replay chain is retained by stored.previous_response_id and reloaded + // from response storage on the next turn; storing replay-expanded input + // here would duplicate history every hop. assert!(stored.input.is_array()); let input_items = stored.input.as_array().unwrap(); assert_eq!(input_items.len(), 1);