Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions model_gateway/src/routers/openai/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use smg_data_connector::{
};
use smg_mcp::{McpOrchestrator, McpToolSession};

use super::provider::Provider;
use super::{
provider::Provider,
stateful_tools::{SharedStatefulToolBootstrapper, StatefulToolBootstrapState},
};
use crate::{
config::RouterConfig, memory::MemoryExecutionContext, middleware,
middleware::TenantRequestMeta, worker::Worker,
Expand Down Expand Up @@ -47,6 +50,7 @@ pub struct SharedComponents {
pub struct ResponsesComponents {
pub shared: Arc<SharedComponents>,
pub mcp_orchestrator: Arc<McpOrchestrator>,
pub stateful_tool_bootstrapper: SharedStatefulToolBootstrapper,
pub response_storage: Arc<dyn ResponseStorage>,
pub conversation_storage: Arc<dyn ConversationStorage>,
pub conversation_item_storage: Arc<dyn ConversationItemStorage>,
Expand Down Expand Up @@ -88,6 +92,13 @@ impl ComponentRefs {
}
}

pub fn stateful_tool_bootstrapper(&self) -> Option<&SharedStatefulToolBootstrapper> {
match self {
ComponentRefs::Shared(_) => None,
ComponentRefs::Responses(r) => Some(&r.stateful_tool_bootstrapper),
}
}

pub fn conversation_storage(&self) -> Option<&Arc<dyn ConversationStorage>> {
match self {
ComponentRefs::Shared(_) => None,
Expand Down Expand Up @@ -130,8 +141,10 @@ pub struct PayloadState {

#[derive(Default)]
pub struct ResponsesPayloadState {
pub client_request: Option<Arc<ResponsesRequest>>,
pub previous_response_id: Option<String>,
pub existing_mcp_list_tools_labels: Vec<String>,
pub stateful_tool_bootstrap: StatefulToolBootstrapState,
}

impl RequestContext {
Expand Down Expand Up @@ -265,20 +278,24 @@ pub struct StorageHandles {
pub struct OwnedStreamingContext {
pub url: String,
pub payload: Value,
pub original_body: ResponsesRequest,
pub request_body: Arc<ResponsesRequest>,
pub client_body: Arc<ResponsesRequest>,
pub previous_response_id: Option<String>,
pub existing_mcp_list_tools_labels: Vec<String>,
pub stateful_tool_bootstrap: StatefulToolBootstrapState,
pub storage: StorageHandles,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

impl RequestContext {
pub fn into_streaming_context(mut self) -> Result<OwnedStreamingContext, &'static str> {
let payload_state = self.take_payload().ok_or("Payload not prepared")?;
let responses_payload_state = self.take_responses_payload().unwrap_or_default();
let original_body = self
.responses_request()
.ok_or("Expected responses request")?
.clone();
let request_body = self
.responses_request_arc()
.ok_or("Expected responses request")?;
let client_body = responses_payload_state
.client_request
.unwrap_or_else(|| Arc::clone(&request_body));
let response = self
.components
.response_storage()
Expand All @@ -303,9 +320,11 @@ impl RequestContext {
Ok(OwnedStreamingContext {
url: payload_state.url,
payload: payload_state.json,
original_body,
request_body,
client_body,
previous_response_id: responses_payload_state.previous_response_id,
existing_mcp_list_tools_labels: responses_payload_state.existing_mcp_list_tools_labels,
stateful_tool_bootstrap: responses_payload_state.stateful_tool_bootstrap,
storage: StorageHandles {
response,
conversation,
Expand Down
28 changes: 25 additions & 3 deletions model_gateway/src/routers/openai/mcp/tool_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::{
mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS},
},
error,
openai::stateful_tools::StatefulToolBootstrapState,
},
};

Expand All @@ -49,14 +50,29 @@ pub(crate) struct ToolLoopState {
pub conversation_history: Vec<Value>,
/// Original user input (preserved for building resume payloads)
pub original_input: ResponseInput,
/// Request-scoped prepared state for hosted/stateful tools.
pub stateful_tool_bootstrap: StatefulToolBootstrapState,
/// MCP bindings already represented by historical `mcp_list_tools` items.
pub existing_mcp_list_tools_labels: HashSet<String>,
/// Transformed output items (mcp_call, web_search_call, etc.) - stored to avoid reconstruction
pub mcp_call_items: Vec<Value>,
}

impl ToolLoopState {
#[cfg(test)]
pub fn new(original_input: ResponseInput, prior_mcp_list_tools_labels: Vec<String>) -> Self {
Self::new_with_bootstrap(
original_input,
prior_mcp_list_tools_labels,
StatefulToolBootstrapState::default(),
)
}

pub fn new_with_bootstrap(
original_input: ResponseInput,
prior_mcp_list_tools_labels: Vec<String>,
stateful_tool_bootstrap: StatefulToolBootstrapState,
) -> Self {
Comment thread
RohanSogani marked this conversation as resolved.
let known_labels = prior_mcp_list_tools_labels
.into_iter()
.collect::<HashSet<_>>();
Expand All @@ -66,6 +82,7 @@ impl ToolLoopState {
total_calls: 0,
conversation_history: Vec::new(),
original_input,
stateful_tool_bootstrap,
existing_mcp_list_tools_labels: known_labels,
mcp_call_items: Vec::new(),
}
Expand Down Expand Up @@ -795,6 +812,7 @@ fn approval_prefix_items(
pub(crate) struct ToolLoopExecutionContext<'a> {
pub original_body: &'a ResponsesRequest,
pub existing_mcp_list_tools_labels: &'a [String],
pub stateful_tool_bootstrap: &'a StatefulToolBootstrapState,
pub session: &'a McpToolSession<'a>,
}

Expand All @@ -810,21 +828,25 @@ pub(crate) async fn execute_tool_loop(
let ToolLoopExecutionContext {
original_body,
existing_mcp_list_tools_labels,
stateful_tool_bootstrap,
session,
} = tool_loop_ctx;

let mut state = ToolLoopState::new(
let mut state = ToolLoopState::new_with_bootstrap(
original_body.input.clone(),
existing_mcp_list_tools_labels.to_vec(),
stateful_tool_bootstrap.clone(),
);
let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);
let base_payload = initial_payload.clone();
let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
let mut current_payload = initial_payload;

info!(
"Starting tool loop: max_tool_calls={:?}, max_iterations={}",
max_tool_calls, DEFAULT_MAX_ITERATIONS
"Starting tool loop: max_tool_calls={:?}, max_iterations={}, prepared_stateful_tools={}",
max_tool_calls,
DEFAULT_MAX_ITERATIONS,
state.stateful_tool_bootstrap.prepared_tools.len()
);
let provider = ApiProvider::from_url(url);
let auth_header = provider.extract_auth_header(headers, worker_api_key);
Expand Down
1 change: 1 addition & 0 deletions model_gateway/src/routers/openai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ mod provider;
pub mod realtime;
pub mod responses;
mod router;
pub mod stateful_tools;

pub use router::OpenAIRouter;
35 changes: 33 additions & 2 deletions model_gateway/src/routers/openai/responses/history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ fn deserialize_items_from_array(array: &Value) -> Vec<ResponseInputOutputItem> {
.map_err(|e| warn!("Failed to deserialize item: {}. Item: {}", e, item))
.ok()
})
.filter(|item| !matches!(item, ResponseInputOutputItem::Reasoning { .. }))
.collect()
})
.unwrap_or_default()
Expand Down Expand Up @@ -321,9 +322,10 @@ pub(crate) fn inject_memory_context(

#[cfg(test)]
mod tests {
use openai_protocol::responses::{ResponseInput, ResponsesRequest};
use openai_protocol::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest};
use serde_json::json;

use super::inject_memory_context;
use super::{deserialize_items_from_array, inject_memory_context};
use crate::routers::common::header_utils::{
ConversationMemoryConfig, LongTermMemoryConfig, ShortTermMemoryConfig,
};
Expand Down Expand Up @@ -357,4 +359,33 @@ mod tests {
}
}
}

#[test]
fn deserialize_items_from_array_drops_reasoning_replay_items() {
let reasoning = json!({
"type": "reasoning",
"id": "rs_1",
"summary": []
});
assert!(
serde_json::from_value::<ResponseInputOutputItem>(reasoning.clone()).is_ok(),
"reasoning fixture must deserialize so this test validates filtering logic"
);

let items = json!([
{
"type": "message",
"id": "msg_1",
"role": "user",
"content": [{ "type": "input_text", "text": "hello" }],
"status": "completed"
},
reasoning
]);

let parsed = deserialize_items_from_array(&items);

assert_eq!(parsed.len(), 1);
assert!(matches!(parsed[0], ResponseInputOutputItem::Message { .. }));
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
27 changes: 17 additions & 10 deletions model_gateway/src/routers/openai/responses/non_streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use serde_json::Value;
use smg_mcp::McpToolSession;
use tracing::warn;

use super::utils::{patch_response_with_request_metadata, restore_original_tools};
use super::utils::{
build_persistence_request_body, patch_response_with_request_metadata, restore_original_tools,
};
use crate::routers::{
common::{
header_utils::{extract_forwardable_request_headers, ApiProvider},
Expand Down Expand Up @@ -39,16 +41,19 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
url,
} = payload_state;
let ResponsesPayloadState {
client_request,
previous_response_id,
existing_mcp_list_tools_labels,
stateful_tool_bootstrap,
} = ctx.take_responses_payload().unwrap_or_default();
Comment thread
coderabbitai[bot] marked this conversation as resolved.

let original_body = match ctx.responses_request() {
let request_body = match ctx.responses_request() {
Some(r) => r,
None => {
return error::internal_error("internal_error", "Expected responses request");
}
};
let client_body = client_request.as_deref().unwrap_or(request_body);
let worker = match ctx.worker() {
Some(w) => w.clone(),
None => {
Expand All @@ -63,7 +68,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
};

// Check for MCP tools and create session if needed
let mcp_servers = if let Some(tools) = original_body.tools.as_deref() {
let mcp_servers = if let Some(tools) = request_body.tools.as_deref() {
ensure_request_mcp_client(mcp_orchestrator, tools).await
} else {
None
Expand All @@ -72,7 +77,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
let mut response_json: Value;

if let Some(mcp_servers) = mcp_servers {
let session_request_id = original_body
let session_request_id = request_body
.request_id
.clone()
.unwrap_or_else(|| format!("req_{}", uuid::Uuid::now_v7()));
Expand All @@ -83,7 +88,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
&session_request_id,
forwarded_headers,
);
if let Some(tools) = original_body.tools.as_deref() {
if let Some(tools) = request_body.tools.as_deref() {
session.configure_response_tools_approval(tools);
}
prepare_mcp_tools_as_functions(&mut payload, &session);
Expand All @@ -95,8 +100,9 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
worker.api_key(),
payload,
ToolLoopExecutionContext {
original_body,
original_body: request_body,
existing_mcp_list_tools_labels: &existing_mcp_list_tools_labels,
stateful_tool_bootstrap: &stateful_tool_bootstrap,
session: &session,
},
)
Expand All @@ -112,7 +118,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
}
}

restore_original_tools(&mut response_json, original_body, Some(&session));
restore_original_tools(&mut response_json, client_body, Some(&session));
} else {
let mut request_builder = ctx.components.client().post(&url).json(&payload);
let provider = ApiProvider::from_url(&url);
Expand Down Expand Up @@ -156,11 +162,11 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
}
};

restore_original_tools(&mut response_json, original_body, None);
restore_original_tools(&mut response_json, client_body, None);
}
patch_response_with_request_metadata(
&mut response_json,
original_body,
client_body,
previous_response_id.as_deref(),
);

Expand All @@ -169,12 +175,13 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
ctx.components.conversation_item_storage(),
ctx.components.response_storage(),
) {
let persistence_body = build_persistence_request_body(request_body, client_body);
if let Err(err) = persist_conversation_items(
conv_storage.clone(),
item_storage.clone(),
resp_storage.clone(),
&response_json,
original_body,
&persistence_body,
ctx.storage_request_context.clone(),
)
.await
Expand Down
Loading
Loading