Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion model_gateway/src/routers/openai/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use smg_data_connector::{
};
use smg_mcp::{McpOrchestrator, McpToolSession};

use super::provider::Provider;
use super::{
provider::Provider,
stateful_tools::{SharedStatefulToolBootstrapper, StatefulToolBootstrapState},
};
use crate::{
config::RouterConfig, memory::MemoryExecutionContext, middleware,
middleware::TenantRequestMeta, worker::Worker,
Expand Down Expand Up @@ -47,6 +50,7 @@ pub struct SharedComponents {
pub struct ResponsesComponents {
pub shared: Arc<SharedComponents>,
pub mcp_orchestrator: Arc<McpOrchestrator>,
pub stateful_tool_bootstrapper: SharedStatefulToolBootstrapper,
pub response_storage: Arc<dyn ResponseStorage>,
pub conversation_storage: Arc<dyn ConversationStorage>,
pub conversation_item_storage: Arc<dyn ConversationItemStorage>,
Expand Down Expand Up @@ -88,6 +92,13 @@ impl ComponentRefs {
}
}

pub fn stateful_tool_bootstrapper(&self) -> Option<&SharedStatefulToolBootstrapper> {
match self {
ComponentRefs::Shared(_) => None,
ComponentRefs::Responses(r) => Some(&r.stateful_tool_bootstrapper),
}
}

pub fn conversation_storage(&self) -> Option<&Arc<dyn ConversationStorage>> {
match self {
ComponentRefs::Shared(_) => None,
Expand Down Expand Up @@ -132,6 +143,7 @@ pub struct PayloadState {
pub struct ResponsesPayloadState {
pub previous_response_id: Option<String>,
pub existing_mcp_list_tools_labels: Vec<String>,
pub stateful_tool_bootstrap: StatefulToolBootstrapState,
}

impl RequestContext {
Expand Down Expand Up @@ -268,6 +280,7 @@ pub struct OwnedStreamingContext {
pub original_body: ResponsesRequest,
pub previous_response_id: Option<String>,
pub existing_mcp_list_tools_labels: Vec<String>,
pub stateful_tool_bootstrap: StatefulToolBootstrapState,
pub storage: StorageHandles,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

Expand Down Expand Up @@ -306,6 +319,7 @@ impl RequestContext {
original_body,
previous_response_id: responses_payload_state.previous_response_id,
existing_mcp_list_tools_labels: responses_payload_state.existing_mcp_list_tools_labels,
stateful_tool_bootstrap: responses_payload_state.stateful_tool_bootstrap,
storage: StorageHandles {
response,
conversation,
Expand Down
28 changes: 25 additions & 3 deletions model_gateway/src/routers/openai/mcp/tool_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::{
mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS},
},
error,
openai::stateful_tools::StatefulToolBootstrapState,
},
};

Expand All @@ -49,14 +50,29 @@ pub(crate) struct ToolLoopState {
pub conversation_history: Vec<Value>,
/// Original user input (preserved for building resume payloads)
pub original_input: ResponseInput,
/// Request-scoped prepared state for hosted/stateful tools.
pub stateful_tool_bootstrap: StatefulToolBootstrapState,
/// MCP bindings already represented by historical `mcp_list_tools` items.
pub existing_mcp_list_tools_labels: HashSet<String>,
/// Transformed output items (mcp_call, web_search_call, etc.) - stored to avoid reconstruction
pub mcp_call_items: Vec<Value>,
}

impl ToolLoopState {
#[cfg(test)]
pub fn new(original_input: ResponseInput, prior_mcp_list_tools_labels: Vec<String>) -> Self {
Self::new_with_bootstrap(
original_input,
prior_mcp_list_tools_labels,
StatefulToolBootstrapState::default(),
)
}

pub fn new_with_bootstrap(
original_input: ResponseInput,
prior_mcp_list_tools_labels: Vec<String>,
stateful_tool_bootstrap: StatefulToolBootstrapState,
) -> Self {
Comment thread
RohanSogani marked this conversation as resolved.
let known_labels = prior_mcp_list_tools_labels
.into_iter()
.collect::<HashSet<_>>();
Expand All @@ -66,6 +82,7 @@ impl ToolLoopState {
total_calls: 0,
conversation_history: Vec::new(),
original_input,
stateful_tool_bootstrap,
existing_mcp_list_tools_labels: known_labels,
mcp_call_items: Vec::new(),
}
Expand Down Expand Up @@ -795,6 +812,7 @@ fn approval_prefix_items(
pub(crate) struct ToolLoopExecutionContext<'a> {
pub original_body: &'a ResponsesRequest,
pub existing_mcp_list_tools_labels: &'a [String],
pub stateful_tool_bootstrap: &'a StatefulToolBootstrapState,
pub session: &'a McpToolSession<'a>,
}

Expand All @@ -810,21 +828,25 @@ pub(crate) async fn execute_tool_loop(
let ToolLoopExecutionContext {
original_body,
existing_mcp_list_tools_labels,
stateful_tool_bootstrap,
session,
} = tool_loop_ctx;

let mut state = ToolLoopState::new(
let mut state = ToolLoopState::new_with_bootstrap(
original_body.input.clone(),
existing_mcp_list_tools_labels.to_vec(),
stateful_tool_bootstrap.clone(),
);
let max_tool_calls = original_body.max_tool_calls.map(|n| n as usize);
let base_payload = initial_payload.clone();
let tools_json = base_payload.get("tools").cloned().unwrap_or(json!([]));
let mut current_payload = initial_payload;

info!(
"Starting tool loop: max_tool_calls={:?}, max_iterations={}",
max_tool_calls, DEFAULT_MAX_ITERATIONS
"Starting tool loop: max_tool_calls={:?}, max_iterations={}, prepared_stateful_tools={}",
max_tool_calls,
DEFAULT_MAX_ITERATIONS,
state.stateful_tool_bootstrap.prepared_tools.len()
);
let provider = ApiProvider::from_url(url);
let auth_header = provider.extract_auth_header(headers, worker_api_key);
Expand Down
1 change: 1 addition & 0 deletions model_gateway/src/routers/openai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ mod provider;
pub mod realtime;
pub mod responses;
mod router;
pub mod stateful_tools;

pub use router::OpenAIRouter;
2 changes: 2 additions & 0 deletions model_gateway/src/routers/openai/responses/non_streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
let ResponsesPayloadState {
previous_response_id,
existing_mcp_list_tools_labels,
stateful_tool_bootstrap,
} = ctx.take_responses_payload().unwrap_or_default();
Comment thread
coderabbitai[bot] marked this conversation as resolved.

let original_body = match ctx.responses_request() {
Expand Down Expand Up @@ -97,6 +98,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
ToolLoopExecutionContext {
original_body,
existing_mcp_list_tools_labels: &existing_mcp_list_tools_labels,
stateful_tool_bootstrap: &stateful_tool_bootstrap,
session: &session,
},
)
Expand Down
91 changes: 71 additions & 20 deletions model_gateway/src/routers/openai/responses/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ use serde_json::to_value;
use super::{
super::{
context::{
ComponentRefs, PayloadState, RequestContext, ResponsesComponents,
ComponentRefs, PayloadState, RequestContext, RequestType, ResponsesComponents,
ResponsesPayloadState, WorkerSelection,
},
provider::ProviderRegistry,
router::resolve_provider,
stateful_tools::{ensure_stateful_tool_bootstrap, StatefulToolBootstrapContext},
},
handle_non_streaming_response, handle_streaming_response,
};
Expand Down Expand Up @@ -132,6 +133,75 @@ pub(in crate::routers::openai) async fn route_responses(
super::history::inject_memory_context(&memory_config, &mut request_body);
}

let mut ctx = RequestContext::for_responses(
Arc::new(request_body.clone()),
headers.cloned(),
Comment thread
RohanSogani marked this conversation as resolved.
Some(model_id.to_string()),
ComponentRefs::Responses(Arc::clone(deps.responses_components)),
);
ctx.storage_request_context = smg_data_connector::current_request_context();
ctx.tenant_request_meta = Some(tenant_meta.clone());
let provider = resolve_provider(deps.provider_registry, worker.as_ref(), model);

ctx.state.worker = Some(WorkerSelection {
worker: Arc::clone(&worker),
provider: Arc::clone(&provider),
});
ctx.state.responses_payload = Some(ResponsesPayloadState {
previous_response_id: loaded_history.previous_response_id,
existing_mcp_list_tools_labels: loaded_history.existing_mcp_list_tools_labels,
..Default::default()
});

let stateful_tool_bootstrapper = match ctx.components.stateful_tool_bootstrapper() {
Some(bootstrapper) => Arc::clone(bootstrapper),
None => {
return error::internal_error(
"internal_error",
"Stateful tool bootstrapper not configured",
);
}
};
Comment thread
coderabbitai[bot] marked this conversation as resolved.
let storage_request_context = ctx.storage_request_context.clone();
let tenant_request_meta = ctx.tenant_request_meta.clone();
let memory_execution_context = ctx.memory_execution_context.clone();
let bootstrap_state = match ctx.state.responses_payload.as_mut() {
Some(responses_payload) => &mut responses_payload.stateful_tool_bootstrap,
None => {
return error::internal_error(
"internal_error",
"Responses payload state not initialized",
);
}
};
if let Err(e) = ensure_stateful_tool_bootstrap(
&mut request_body,
bootstrap_state,
stateful_tool_bootstrapper.as_ref(),
Comment thread
RohanSogani marked this conversation as resolved.
StatefulToolBootstrapContext {
Comment thread
RohanSogani marked this conversation as resolved.
headers,
storage_request_context: storage_request_context.as_ref(),
memory_execution_context: &memory_execution_context,
tenant_request_meta: tenant_request_meta.as_ref(),
},
)
.await
{
Metrics::record_router_error(
metrics_labels::ROUTER_OPENAI,
metrics_labels::BACKEND_EXTERNAL,
metrics_labels::CONNECTION_HTTP,
model,
metrics_labels::ENDPOINT_RESPONSES,
metrics_labels::ERROR_INTERNAL,
);
return error::internal_error(
"stateful_tool_bootstrap_failed",
format!("Failed to prepare stateful tool request state: {e}"),
);
}
Comment thread
RohanSogani marked this conversation as resolved.
ctx.input.request_type = RequestType::Responses(Arc::new(request_body.clone()));

request_body.store = Some(false);
if let ResponseInput::Items(ref mut items) = request_body.input {
items.retain(|item| !matches!(item, ResponseInputOutputItem::Reasoning { .. }));
Comment thread
RohanSogani marked this conversation as resolved.
Outdated
Expand All @@ -155,7 +225,6 @@ pub(in crate::routers::openai) async fn route_responses(
}
};

let provider = resolve_provider(deps.provider_registry, worker.as_ref(), model);
if let Err(e) = provider.transform_request(&mut payload, Endpoint::Responses) {
Metrics::record_router_error(
metrics_labels::ROUTER_OPENAI,
Expand All @@ -168,28 +237,10 @@ pub(in crate::routers::openai) async fn route_responses(
return error::bad_request("invalid_request", format!("Provider transform error: {e}"));
}

let mut ctx = RequestContext::for_responses(
Arc::new(body.clone()),
headers.cloned(),
Some(model_id.to_string()),
ComponentRefs::Responses(Arc::clone(deps.responses_components)),
);
ctx.storage_request_context = smg_data_connector::current_request_context();
ctx.tenant_request_meta = Some(tenant_meta.clone());

ctx.state.worker = Some(WorkerSelection {
worker: Arc::clone(&worker),
provider: Arc::clone(&provider),
});

ctx.state.payload = Some(PayloadState {
json: payload,
url: format!("{}/v1/responses", worker.url()),
});
ctx.state.responses_payload = Some(ResponsesPayloadState {
previous_response_id: loaded_history.previous_response_id,
existing_mcp_list_tools_labels: loaded_history.existing_mcp_list_tools_labels,
});

let response = if ctx.is_streaming() {
handle_streaming_response(ctx).await
Expand Down
8 changes: 7 additions & 1 deletion model_gateway/src/routers/openai/responses/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ pub(super) fn handle_streaming_with_tool_interception(
let original_request = req.original_body;
let previous_response_id = req.previous_response_id;
let existing_mcp_list_tools_labels = req.existing_mcp_list_tools_labels;
let stateful_tool_bootstrap = req.stateful_tool_bootstrap;
let url = req.url;
let storage = req.storage;

Expand All @@ -698,9 +699,14 @@ pub(super) fn handle_streaming_with_tool_interception(
reason = "fire-and-forget MCP tool loop; gateway shutdown need not wait for individual tool loops"
)]
tokio::spawn(async move {
let mut state = ToolLoopState::new(
let mut state = ToolLoopState::new_with_bootstrap(
original_request.input.clone(),
existing_mcp_list_tools_labels,
stateful_tool_bootstrap,
);
tracing::debug!(
prepared_stateful_tools = state.stateful_tool_bootstrap.prepared_tools.len(),
"Starting streaming tool loop"
);
let max_tool_calls = original_request.max_tool_calls.map(|n| n as usize);

Expand Down
11 changes: 10 additions & 1 deletion model_gateway/src/routers/openai/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use super::{
health,
provider::ProviderRegistry,
responses::route::{self as responses_route, ResponsesRouterContext},
stateful_tools::{NoOpStatefulToolBootstrapper, SharedStatefulToolBootstrapper},
};
use crate::{
app_context::AppContext,
Expand Down Expand Up @@ -76,11 +77,18 @@ impl std::fmt::Debug for OpenAIRouter {
}

impl OpenAIRouter {
pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
Self::new_with_stateful_tool_bootstrapper(ctx, Arc::new(NoOpStatefulToolBootstrapper)).await
}

#[expect(
clippy::unused_async,
reason = "async for API consistency with other router constructors"
)]
pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
pub async fn new_with_stateful_tool_bootstrapper(
ctx: &Arc<AppContext>,
stateful_tool_bootstrapper: SharedStatefulToolBootstrapper,
) -> Result<Self, String> {
let worker_registry = ctx.worker_registry.clone();
let mcp_orchestrator = ctx
.mcp_orchestrator
Expand All @@ -96,6 +104,7 @@ impl OpenAIRouter {
let responses_components = Arc::new(ResponsesComponents {
shared: Arc::clone(&shared_components),
mcp_orchestrator: mcp_orchestrator.clone(),
stateful_tool_bootstrapper,
response_storage: ctx.response_storage.clone(),
conversation_storage: ctx.conversation_storage.clone(),
conversation_item_storage: ctx.conversation_item_storage.clone(),
Expand Down
Loading
Loading