diff --git a/model_gateway/src/config/builder.rs b/model_gateway/src/config/builder.rs
index 209daefdc..59dd6f9a5 100644
--- a/model_gateway/src/config/builder.rs
+++ b/model_gateway/src/config/builder.rs
@@ -183,6 +183,11 @@ impl RouterConfigBuilder {
         self
     }

+    pub fn max_conversation_history_items(mut self, max_items: usize) -> Self {
+        self.config.max_conversation_history_items = max_items;
+        self
+    }
+
     pub fn worker_startup_timeout_secs(mut self, timeout: u64) -> Self {
         self.config.worker_startup_timeout_secs = timeout;
         self
diff --git a/model_gateway/src/config/types.rs b/model_gateway/src/config/types.rs
index 1a675bd4b..b01fefece 100644
--- a/model_gateway/src/config/types.rs
+++ b/model_gateway/src/config/types.rs
@@ -106,6 +106,9 @@ pub struct RouterConfig {
     pub storage_context_headers: HashMap<String, String>,
     #[serde(default)]
     pub memory_runtime: MemoryRuntimeConfig,
+    /// Maximum conversation items to load into request context.
+    #[serde(default = "default_max_conversation_history_items")]
+    pub max_conversation_history_items: usize,
     #[serde(default)]
     pub background: BackgroundConfig,
     #[serde(default)]
@@ -219,6 +222,10 @@ fn default_load_monitor_interval_secs() -> u64 {
     10
 }

+fn default_max_conversation_history_items() -> usize {
+    100
+}
+
 fn default_enable_l0() -> bool {
     false
 }
@@ -628,8 +635,9 @@ impl Default for RouterConfig {
             policy: PolicyConfig::Random,
             host: "0.0.0.0".to_string(),
             port: 3001,
-            max_payload_size: 536_870_912, // 512MB
-            request_timeout_secs: 1800,    // 30 minutes
+            max_payload_size: 536_870_912, // 512MB
+            request_timeout_secs: 1800,    // 30 minutes
+            max_conversation_history_items: default_max_conversation_history_items(),
             worker_startup_timeout_secs: 1800, // 30 minutes for large model loading
             worker_startup_check_interval_secs: 30,
             load_monitor_interval_secs: 10,
@@ -766,6 +774,7 @@ mod tests {
         assert_eq!(config.port, 3001);
         assert_eq!(config.max_payload_size, 536_870_912);
         assert_eq!(config.request_timeout_secs, 1800);
+        assert_eq!(config.max_conversation_history_items, 100);
         assert_eq!(config.worker_startup_timeout_secs, 1800);
         assert_eq!(config.worker_startup_check_interval_secs, 30);
         assert_eq!(config.load_monitor_interval_secs, 10);
@@ -955,6 +964,7 @@ stream_retention_secs: 3600
         assert!(deserialized.skills_enabled);
         assert!(deserialized.skills.is_none());
+        assert_eq!(deserialized.max_conversation_history_items, 100);
         assert!(!deserialized.tenant_resolution.trust_tenant_header);
         assert_eq!(
             deserialized.tenant_resolution.tenant_header_name,
diff --git a/model_gateway/src/config/validation.rs b/model_gateway/src/config/validation.rs
index 7dfe99d00..a3db30c0b 100644
--- a/model_gateway/src/config/validation.rs
+++ b/model_gateway/src/config/validation.rs
@@ -495,6 +495,14 @@ impl ConfigValidator {
             });
         }

+        if config.max_conversation_history_items == 0 {
+            return Err(ConfigError::InvalidValue {
+                field: "max_conversation_history_items".to_string(),
+                value: config.max_conversation_history_items.to_string(),
+                reason: "Must be > 0".to_string(),
+            });
+        }
+
         if config.queue_size > 0 && config.queue_timeout_secs == 0 {
             return Err(ConfigError::InvalidValue {
                 field: "queue_timeout_secs".to_string(),
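Taken together, the three config-side changes give the new knob one coherent surface: a builder setter, a serde default of 100, and a validator that rejects 0. A minimal wiring sketch (illustrative only; it assumes RouterConfigBuilder::new() and a build() entry point that runs ConfigValidator, neither of which is shown in this diff):

    // Hypothetical usage; the builder entry points are assumed, not shown above.
    let config = RouterConfigBuilder::new()
        .max_conversation_history_items(200) // raise the default cap of 100
        .worker_startup_timeout_secs(1800)
        .build();
    assert_eq!(config.max_conversation_history_items, 200);
    // A value of 0 would fail validation with
    // ConfigError::InvalidValue { field: "max_conversation_history_items", .. }.
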
diff --git a/model_gateway/src/memory/context.rs b/model_gateway/src/memory/context.rs
index 7395ca9f6..ff11e1004 100644
--- a/model_gateway/src/memory/context.rs
+++ b/model_gateway/src/memory/context.rs
@@ -48,6 +48,8 @@ pub struct MemoryExecutionContext {
     pub subject_id: Option<String>,
     pub embedding_model: Option<String>,
     pub extraction_model: Option<String>,
+    pub stm_enabled: MemoryExecutionState,
+    pub stm_condenser_model_id: Option<String>,
 }

 impl MemoryExecutionContext {
@@ -68,6 +70,8 @@ impl MemoryExecutionContext {
         }
         let store_ltm_requested = policy.allows_ltm_store();
         let recall_requested = policy.allows_recall();
+        let stm_enabled =
+            MemoryExecutionState::from_requested_and_runtime(headers.stm_enabled, runtime.enabled);

         Self {
             store_ltm: MemoryExecutionState::from_requested_and_runtime(
@@ -82,6 +86,11 @@ impl MemoryExecutionContext {
             subject_id: headers.subject_id.clone(),
             embedding_model: headers.embedding_model.clone(),
             extraction_model: headers.extraction_model.clone(),
+            stm_enabled,
+            stm_condenser_model_id: stm_enabled
+                .active()
+                .then_some(headers.stm_condenser_model_id.clone())
+                .flatten(),
         }
     }
 }
@@ -151,6 +160,8 @@ mod tests {
     fn store_and_recall_requested_but_not_active_when_runtime_disabled() {
         let headers = MemoryHeaderView {
             policy: Some("store_and_recall".to_string()),
+            stm_enabled: true,
+            stm_condenser_model_id: Some("condense-1".to_string()),
             ..MemoryHeaderView::default()
         };

@@ -159,6 +170,8 @@ mod tests {
         assert_eq!(ctx.store_ltm, MemoryExecutionState::GatedOff);
         assert_eq!(ctx.recall, MemoryExecutionState::GatedOff);
         assert_eq!(ctx.policy_mode, MemoryPolicyMode::StoreAndRecall);
+        assert_eq!(ctx.stm_enabled, MemoryExecutionState::GatedOff);
+        assert_eq!(ctx.stm_condenser_model_id, None);
     }

     #[test]
@@ -167,7 +180,7 @@ mod tests {
         let mut headers = HeaderMap::new();
         headers.insert(
             "x-conversation-memory-config",
             HeaderValue::from_static(
-                r#"{"long_term_memory":{"enabled":true,"policy":"store_and_recall","subject_id":" subject_abc ","embedding_model_id":" text-embedding-3-small ","extraction_model_id":" gpt-4.1-mini "}}"#,
+                r#"{"long_term_memory":{"enabled":true,"policy":"store_and_recall","subject_id":" subject_abc ","embedding_model_id":" text-embedding-3-small ","extraction_model_id":" gpt-4.1-mini "},"short_term_memory":{"enabled":true,"condenser_model_id":" condense-1 "}}"#,
             ),
         );

@@ -182,5 +195,7 @@ mod tests {
             Some("text-embedding-3-small")
         );
         assert_eq!(ctx.extraction_model.as_deref(), Some("gpt-4.1-mini"));
+        assert_eq!(ctx.stm_enabled, MemoryExecutionState::Active);
+        assert_eq!(ctx.stm_condenser_model_id.as_deref(), Some("condense-1"));
     }
 }
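The STM gate reuses the requested-vs-runtime lattice that LTM store/recall already go through. A compressed model of the semantics the tests pin down (MemoryExecutionState's real definition is not in this diff; the enum below is an assumed stand-in for illustration):

    // Assumed shape; only from_requested_and_runtime() and active() are visible above.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum MemoryExecutionState {
        Active,       // requested by headers AND enabled by runtime config
        GatedOff,     // requested by headers but the runtime flag is off
        NotRequested, // headers never asked for it
    }

    impl MemoryExecutionState {
        fn from_requested_and_runtime(requested: bool, runtime_enabled: bool) -> Self {
            match (requested, runtime_enabled) {
                (true, true) => Self::Active,
                (true, false) => Self::GatedOff,
                (false, _) => Self::NotRequested,
            }
        }
        fn active(self) -> bool {
            matches!(self, Self::Active)
        }
    }

    fn main() {
        // Mirrors store_and_recall_requested_but_not_active_when_runtime_disabled:
        assert_eq!(
            MemoryExecutionState::from_requested_and_runtime(true, false),
            MemoryExecutionState::GatedOff
        );
        // And the happy path asserted in the header-parsing test:
        assert!(MemoryExecutionState::from_requested_and_runtime(true, true).active());
    }

The practical consequence is the .then_some(...).flatten() in the constructor: a condenser model named in the headers is dropped unless the gate lands on Active, so nothing downstream has to re-check the runtime flag.
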
diff --git a/model_gateway/src/routers/common/header_utils.rs b/model_gateway/src/routers/common/header_utils.rs
index 5528cf340..06f7ea467 100644
--- a/model_gateway/src/routers/common/header_utils.rs
+++ b/model_gateway/src/routers/common/header_utils.rs
@@ -21,6 +21,8 @@ pub struct MemoryHeaderView {
     pub subject_id: Option<String>,
     pub embedding_model: Option<String>,
     pub extraction_model: Option<String>,
+    pub stm_enabled: bool,
+    pub stm_condenser_model_id: Option<String>,
 }

 impl MemoryHeaderView {
@@ -35,6 +37,7 @@ impl MemoryHeaderView {
             return Self::default();
         };
         let ltm_enabled = config.long_term_memory.enabled;
+        let stm_enabled = config.short_term_memory.enabled;
         let policy = if ltm_enabled {
             config
                 .long_term_memory
@@ -54,6 +57,10 @@ impl MemoryHeaderView {
             extraction_model: ltm_enabled
                 .then_some(config.long_term_memory.extraction_model_id)
                 .flatten(),
+            stm_enabled,
+            stm_condenser_model_id: stm_enabled
+                .then_some(config.short_term_memory.condenser_model_id)
+                .flatten(),
         }
     }
 }
@@ -543,6 +550,8 @@ mod tests {
         assert_eq!(view.subject_id, None);
         assert_eq!(view.embedding_model, None);
         assert_eq!(view.extraction_model, None);
+        assert!(!view.stm_enabled);
+        assert_eq!(view.stm_condenser_model_id, None);
     }

     #[test]
@@ -561,6 +570,8 @@ mod tests {
         assert_eq!(view.subject_id, None);
         assert_eq!(view.embedding_model, None);
         assert_eq!(view.extraction_model, None);
+        assert!(view.stm_enabled);
+        assert_eq!(view.stm_condenser_model_id.as_deref(), Some("cond-1"));
     }

     #[test]
@@ -582,6 +593,8 @@ mod tests {
             Some("text-embedding-3-small")
         );
         assert_eq!(view.extraction_model.as_deref(), Some("gpt-4.1-mini"));
+        assert!(!view.stm_enabled);
+        assert_eq!(view.stm_condenser_model_id, None);
     }

     #[test]
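On the wire all of this is driven by one header. A client-side sketch of a payload that flips the new fields on (it mirrors the JSON exercised by the tests; HeaderMap/HeaderValue are the same types the tests use, assumed here to come from the http crate):

    use http::{HeaderMap, HeaderValue};

    fn main() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-conversation-memory-config",
            HeaderValue::from_static(
                r#"{"short_term_memory":{"enabled":true,"condenser_model_id":"cond-1"}}"#,
            ),
        );
        // Per the tests above, parsing this view yields:
        //   stm_enabled == true
        //   stm_condenser_model_id == Some("cond-1")
        // With short_term_memory absent, both fall back to false / None.
        assert!(headers.contains_key("x-conversation-memory-config"));
    }

Whitespace padding inside string values is trimmed during parsing, as the LTM assertions (" gpt-4.1-mini " becoming "gpt-4.1-mini") demonstrate.
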
diff --git a/model_gateway/src/routers/common/persistence_utils.rs b/model_gateway/src/routers/common/persistence_utils.rs
index 40e4904cc..f273cfd4e 100644
--- a/model_gateway/src/routers/common/persistence_utils.rs
+++ b/model_gateway/src/routers/common/persistence_utils.rs
@@ -10,11 +10,14 @@ use openai_protocol::responses::{
 use serde_json::{json, Value};
 use smg_data_connector::{
     with_request_context, ConversationId, ConversationItem, ConversationItemId,
-    ConversationItemStorage, ConversationStorage, NewConversationItem,
+    ConversationItemStorage, ConversationMemoryStatus, ConversationMemoryType,
+    ConversationMemoryWriter, ConversationStorage, NewConversationItem, NewConversationMemory,
     RequestContext as StorageRequestContext, ResponseId, ResponseStorage, StoredResponse,
 };
 use tracing::{debug, info, warn};

+use crate::memory::MemoryExecutionContext;
+
 // ============================================================================
 // Constants
 // ============================================================================
@@ -375,6 +378,162 @@ async fn link_items_to_conversation(
     Ok(())
 }

+#[expect(
+    clippy::manual_is_multiple_of,
+    reason = "usize::is_multiple_of is not stable; % remainder check is the portable equivalent"
+)]
+fn should_enqueue_stmo(user_turns: usize) -> bool {
+    user_turns >= 4 && (user_turns - 1) % 3 == 0
+}
+
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
+pub struct ConversationTurnInfo {
+    pub user_turns: usize,
+    pub total_items: usize,
+}
+
+pub fn count_conversation_turn_info(input: &ResponseInput) -> ConversationTurnInfo {
+    match input {
+        ResponseInput::Text(_) => ConversationTurnInfo {
+            user_turns: 1,
+            total_items: 1,
+        },
+        ResponseInput::Items(items) => {
+            let user_turns = items
+                .iter()
+                .filter(|item| match item {
+                    ResponseInputOutputItem::SimpleInputMessage { role, .. } => {
+                        role.eq_ignore_ascii_case("user")
+                    }
+                    ResponseInputOutputItem::Message { role, .. } => {
+                        role.eq_ignore_ascii_case("user")
+                    }
+                    _ => false,
+                })
+                .count();
+            ConversationTurnInfo {
+                user_turns,
+                total_items: items.len(),
+            }
+        }
+    }
+}
+
+async fn maybe_schedule_stmo_after_persist(
+    conversation_memory_writer: &Arc<dyn ConversationMemoryWriter>,
+    memory_execution_context: &MemoryExecutionContext,
+    conversation_id: &ConversationId,
+    response_id: &ResponseId,
+    user_turns: usize,
+    total_items: usize,
+) -> Result<bool, String> {
+    if !memory_execution_context.stm_enabled.active() {
+        return Ok(false);
+    }
+
+    if !should_enqueue_stmo(user_turns) {
+        return Ok(false);
+    }
+
+    let mut memory_config = json!({
+        "last_index": user_turns,
+        "target_item_end": total_items,
+    });
+
+    if let Some(model_id) = &memory_execution_context.stm_condenser_model_id {
+        memory_config["condenser_model"] = json!(model_id);
+    }
+
+    let memory_config = serde_json::to_string(&memory_config)
+        .map_err(|e| format!("Failed to serialize STMO memory config: {e}"))?;
+
+    let input = NewConversationMemory {
+        conversation_id: conversation_id.clone(),
+        conversation_version: None,
+        response_id: Some(response_id.clone()),
+        memory_type: ConversationMemoryType::Stmo,
+        status: ConversationMemoryStatus::Ready,
+        attempt: 0,
+        owner_id: None,
+        next_run_at: Utc::now(),
+        lease_until: None,
+        content: None,
+        memory_config: Some(memory_config),
+        scope_id: None,
+        error_msg: None,
+    };
+
+    conversation_memory_writer
+        .create_memory(input)
+        .await
+        .map_err(|e| format!("Failed to enqueue STMO memory: {e}"))?;
+
+    Ok(true)
+}
+
+async fn handle_stmo_after_persist(
+    conversation_memory_writer: &Arc<dyn ConversationMemoryWriter>,
+    memory_execution_context: &MemoryExecutionContext,
+    conversation_id: &ConversationId,
+    response_id: &ResponseId,
+    conversation_turn_info: Option<ConversationTurnInfo>,
+    output_item_count: usize,
+) {
+    if !memory_execution_context.stm_enabled.active() {
+        return;
+    }
+
+    let Some(turn_info) = conversation_turn_info else {
+        debug!(
+            conversation_id = %conversation_id.0,
+            response_id = %response_id.0,
+            "STMO skipped: missing conversation turn info"
+        );
+        return;
+    };
+
+    let user_turns = turn_info.user_turns;
+    let total_items = turn_info.total_items + output_item_count;
+
+    match maybe_schedule_stmo_after_persist(
+        conversation_memory_writer,
+        memory_execution_context,
+        conversation_id,
+        response_id,
+        user_turns,
+        total_items,
+    )
+    .await
+    {
+        Ok(true) => {
+            info!(
+                conversation_id = %conversation_id.0,
+                response_id = %response_id.0,
+                user_turns,
+                total_items,
+                "Enqueued STMO memory condensation job"
+            );
+        }
+        Ok(false) => {
+            debug!(
+                conversation_id = %conversation_id.0,
+                response_id = %response_id.0,
+                user_turns,
+                total_items,
+                "STMO not enqueued for this response boundary"
+            );
+        }
+        Err(e) => {
+            warn!(
+                conversation_id = %conversation_id.0,
+                response_id = %response_id.0,
+                error = %e,
+                "Failed to enqueue STMO memory job; continuing without failing response"
+            );
+        }
+    }
+}
+
 /// Persist conversation items
 ///
 /// This function:
 /// 1. Extracts input items from the original request
 /// 2. Extracts output items from the response
 /// 3. Stores ALL items in response storage (always)
 /// 4. If conversation provided, also links items to conversation
+#[expect(
+    clippy::too_many_arguments,
+    reason = "persistence entrypoint assembles all storage handles and request context in one call"
+)]
 pub async fn persist_conversation_items(
     conversation_storage: Arc<dyn ConversationStorage>,
     item_storage: Arc<dyn ConversationItemStorage>,
+    conversation_memory_writer: Arc<dyn ConversationMemoryWriter>,
     response_storage: Arc<dyn ResponseStorage>,
     response_json: &Value,
     original_body: &ResponsesRequest,
     request_context: Option<StorageRequestContext>,
+    memory_execution_context: MemoryExecutionContext,
+    conversation_turn_info: Option<ConversationTurnInfo>,
 ) -> Result<(), String> {
     let inner = persist_conversation_items_inner(
         conversation_storage,
         item_storage,
+        conversation_memory_writer,
         response_storage,
         response_json,
         original_body,
+        memory_execution_context,
+        conversation_turn_info,
     );
     match request_context {
         Some(ctx) => with_request_context(ctx, inner).await,
         None => inner.await,
     }
 }

+#[expect(
+    clippy::too_many_arguments,
+    reason = "inner persistence fn assembles all storage handles, memory context, and turn info in one flow"
+)]
 async fn persist_conversation_items_inner(
     conversation_storage: Arc<dyn ConversationStorage>,
     item_storage: Arc<dyn ConversationItemStorage>,
+    conversation_memory_writer: Arc<dyn ConversationMemoryWriter>,
     response_storage: Arc<dyn ResponseStorage>,
     response_json: &Value,
     original_body: &ResponsesRequest,
+    memory_execution_context: MemoryExecutionContext,
+    conversation_turn_info: Option<ConversationTurnInfo>,
 ) -> Result<(), String> {
     // Respect store=false: skip persistence entirely (matches official API behavior)
     if !original_body.store.unwrap_or(true) {
@@ -467,6 +643,17 @@ async fn persist_conversation_items_inner(
         response_id_str,
     )
     .await?;
+
+    handle_stmo_after_persist(
+        &conversation_memory_writer,
+        &memory_execution_context,
+        &conv_id,
+        &response_id,
+        conversation_turn_info,
+        output_items.len(),
+    )
+    .await;
+
     info!(
         conversation_id = %conv_id.0,
         response_id = %response_id.0,
@@ -485,3 +672,110 @@ async fn persist_conversation_items_inner(

     Ok(())
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn stmo_fires_at_boundary_turns() {
+        // Fires at 4, 7, 10 — not before or between
+        assert!(!should_enqueue_stmo(3));
+        assert!(should_enqueue_stmo(4));
+        assert!(!should_enqueue_stmo(5));
+        assert!(!should_enqueue_stmo(6));
+        assert!(should_enqueue_stmo(7));
+        assert!(should_enqueue_stmo(10));
+    }
+
+    #[test]
+    fn count_user_turns_case_insensitive() {
+        let input = ResponseInput::Items(vec![
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("u1".to_string()),
+                role: "user".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("a1".to_string()),
+                role: "assistant".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("u2".to_string()),
+                role: "User".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::FunctionToolCall {
+                id: "fc_1".to_string(),
+                call_id: "call_1".to_string(),
+                name: "tool".to_string(),
+                arguments: "{}".to_string(),
+                output: None,
+                status: None,
+            },
+        ]);
+        let info = count_conversation_turn_info(&input);
+        assert_eq!(info.user_turns, 2);
+        assert_eq!(info.total_items, 4);
+    }
+
+    #[test]
+    fn total_items_exceeds_user_turns() {
+        let input = ResponseInput::Items(vec![
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("u1".to_string()),
+                role: "user".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("a1".to_string()),
+                role: "assistant".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("u2".to_string()),
+                role: "user".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("a2".to_string()),
+                role: "assistant".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("u3".to_string()),
+                role: "user".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("a3".to_string()),
+                role: "assistant".to_string(),
+                r#type: None,
+                phase: None,
+            },
+            ResponseInputOutputItem::SimpleInputMessage {
+                content: StringOrContentParts::String("u4".to_string()),
+                role: "user".to_string(),
+                r#type: None,
+                phase: None,
+            },
+        ]);
+
+        let info = count_conversation_turn_info(&input);
+        let target_item_end = info.total_items + 2; // response output contributes after load
+
+        assert_eq!(info.user_turns, 4);
+        assert_eq!(info.total_items, 7);
+        assert_eq!(target_item_end, 9); // total > user turns
+        assert!(should_enqueue_stmo(info.user_turns));
+    }
+}
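Two things are worth pinning down from the block above: when a job fires and what it carries. A self-contained check of both (the predicate is copied verbatim from the diff; the payload mirrors maybe_schedule_stmo_after_persist and needs only the serde_json crate; the concrete numbers are illustrative):

    use serde_json::json;

    // Copied from the diff: first fire at 4 user turns, then every 3rd turn after.
    fn should_enqueue_stmo(user_turns: usize) -> bool {
        user_turns >= 4 && (user_turns - 1) % 3 == 0
    }

    fn main() {
        let firing: Vec<usize> = (1..=13).filter(|&n| should_enqueue_stmo(n)).collect();
        assert_eq!(firing, vec![4, 7, 10, 13]);

        // Shape of the memory_config string stored on each queued STMO job.
        // "condenser_model" is only attached when the request headers named one.
        let mut memory_config = json!({
            "last_index": 4,       // user turns counted at enqueue time
            "target_item_end": 9,  // input items (raw DB view) + response output items
        });
        memory_config["condenser_model"] = json!("condense-1");
        let serialized = serde_json::to_string(&memory_config).unwrap();
        assert!(serialized.contains("\"last_index\":4"));
    }

Note the asymmetry baked into handle_stmo_after_persist: a missing turn-info bundle downgrades to a debug log and a skipped job, and an enqueue failure downgrades to a warning; neither ever fails the user-visible response.
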
diff --git a/model_gateway/src/routers/grpc/common/responses/context.rs b/model_gateway/src/routers/grpc/common/responses/context.rs
index d5835a1dc..38b46cc35 100644
--- a/model_gateway/src/routers/grpc/common/responses/context.rs
+++ b/model_gateway/src/routers/grpc/common/responses/context.rs
@@ -10,7 +10,24 @@ use smg_data_connector::{
 };
 use smg_mcp::McpOrchestrator;

-use crate::routers::grpc::{context::SharedComponents, pipeline::RequestPipeline};
+use crate::{
+    memory::MemoryExecutionContext,
+    routers::grpc::{context::SharedComponents, pipeline::RequestPipeline},
+};
+
+/// Bundled storage handles for persistence operations.
+///
+/// Groups the four storage backends that every persistence call needs so they
+/// can be passed as a single unit rather than four individual arguments.
+/// Mirrors the pattern introduced in the LTM pipeline (PR #1357) so the two
+/// code paths stay consistent and future merges remain clean.
+#[derive(Clone)]
+pub(crate) struct PersistenceHandles {
+    pub response_storage: Arc<dyn ResponseStorage>,
+    pub conversation_storage: Arc<dyn ConversationStorage>,
+    pub conversation_item_storage: Arc<dyn ConversationItemStorage>,
+    pub conversation_memory_writer: Arc<dyn ConversationMemoryWriter>,
+}

 /// Context for /v1/responses endpoint
 ///
@@ -24,50 +41,46 @@ pub(crate) struct ResponsesContext {
     /// Shared components (tokenizer, parsers)
     pub components: Arc<SharedComponents>,

-    /// Response storage backend
-    pub response_storage: Arc<dyn ResponseStorage>,
-
-    /// Conversation storage backend
-    pub conversation_storage: Arc<dyn ConversationStorage>,
-
-    /// Conversation item storage backend
-    pub conversation_item_storage: Arc<dyn ConversationItemStorage>,
-
-    /// Conversation memory writer (can be NoOp depending on backend)
-    pub conversation_memory_writer: Arc<dyn ConversationMemoryWriter>,
+    /// Bundled storage handles for persistence operations.
+    pub persistence: PersistenceHandles,

     /// MCP orchestrator for tool support
     pub mcp_orchestrator: Arc<McpOrchestrator>,

     /// Storage hook request context extracted from HTTP headers by middleware.
     pub request_context: Option<StorageRequestContext>,
+
+    /// Maximum conversation history items to load into request context.
+    pub max_conversation_history_items: usize,
+
+    /// Memory execution context derived from per-request headers.
+    ///
+    /// Controls whether LTM store/recall and STM condensation are active for
+    /// this request. Built from `x-conversation-memory-config` + the runtime
+    /// feature flag at the gRPC entry point and threaded down to the persistence
+    /// layer so it can gate memory side-effects without re-parsing headers.
+    pub memory_execution_context: MemoryExecutionContext,
 }

 impl ResponsesContext {
     /// Create a new responses context.
-    #[expect(
-        clippy::too_many_arguments,
-        reason = "responses context assembles shared pipeline + storage handles in one place"
-    )]
     pub fn new(
         pipeline: Arc<RequestPipeline>,
         components: Arc<SharedComponents>,
-        response_storage: Arc<dyn ResponseStorage>,
-        conversation_storage: Arc<dyn ConversationStorage>,
-        conversation_item_storage: Arc<dyn ConversationItemStorage>,
-        conversation_memory_writer: Arc<dyn ConversationMemoryWriter>,
+        persistence: PersistenceHandles,
         mcp_orchestrator: Arc<McpOrchestrator>,
         request_context: Option<StorageRequestContext>,
+        max_conversation_history_items: usize,
+        memory_execution_context: MemoryExecutionContext,
     ) -> Self {
         Self {
             pipeline,
             components,
-            response_storage,
-            conversation_storage,
-            conversation_item_storage,
-            conversation_memory_writer,
+            persistence,
             mcp_orchestrator,
             request_context,
+            max_conversation_history_items,
+            memory_execution_context,
         }
     }
 }
diff --git a/model_gateway/src/routers/grpc/common/responses/handlers.rs b/model_gateway/src/routers/grpc/common/responses/handlers.rs
index 7f2791c98..6e0448f2f 100644
--- a/model_gateway/src/routers/grpc/common/responses/handlers.rs
+++ b/model_gateway/src/routers/grpc/common/responses/handlers.rs
@@ -16,7 +16,12 @@ pub(crate) async fn cancel_response_impl(ctx: &ResponsesContext, response_id: &str)
     let resp_id = ResponseId::from(response_id);

     // Check if response exists
-    match ctx.response_storage.get_response(&resp_id).await {
+    match ctx
+        .persistence
+        .response_storage
+        .get_response(&resp_id)
+        .await
+    {
         Ok(Some(stored_response)) => {
             let current_status = stored_response
                 .raw_response
diff --git a/model_gateway/src/routers/grpc/common/responses/mod.rs b/model_gateway/src/routers/grpc/common/responses/mod.rs
index fd94e65ff..02ca02f4c 100644
--- a/model_gateway/src/routers/grpc/common/responses/mod.rs
+++ b/model_gateway/src/routers/grpc/common/responses/mod.rs
@@ -6,7 +6,7 @@ pub(crate) mod streaming;
 pub(crate) mod utils;

 // Re-export commonly used items
-pub(crate) use context::ResponsesContext;
+pub(crate) use context::{PersistenceHandles, ResponsesContext};
 pub(crate) use streaming::build_sse_response;
 pub(crate) use utils::{ensure_mcp_connection, persist_response_if_needed};
diff --git a/model_gateway/src/routers/grpc/common/responses/utils.rs b/model_gateway/src/routers/grpc/common/responses/utils.rs
index 66b09b63d..dfb8ccf44 100644
--- a/model_gateway/src/routers/grpc/common/responses/utils.rs
+++ b/model_gateway/src/routers/grpc/common/responses/utils.rs
@@ -8,19 +8,19 @@ use openai_protocol::{
     responses::{ResponseTool, ResponsesRequest, ResponsesResponse},
 };
 use serde_json::to_value;
-use smg_data_connector::{
-    ConversationItemStorage, ConversationStorage, RequestContext as StorageRequestContext,
-    ResponseStorage,
-};
+use smg_data_connector::RequestContext as StorageRequestContext;
 use smg_mcp::{McpOrchestrator, McpServerBinding};
 use tracing::{debug, error, warn};

 use crate::{
+    memory::MemoryExecutionContext,
     routers::{
         common::{
-            mcp_utils::ensure_request_mcp_client, persistence_utils::persist_conversation_items,
+            mcp_utils::ensure_request_mcp_client,
+            persistence_utils::{persist_conversation_items, ConversationTurnInfo},
         },
         error,
+        grpc::common::responses::context::PersistenceHandles,
     },
     worker::WorkerRegistry,
 };
@@ -151,9 +151,9 @@ pub(crate) fn extract_tools_from_response_tools(
 /// Common helper function to avoid duplication across sync and streaming paths
 /// in both harmony and regular responses implementations.
 pub(crate) async fn persist_response_if_needed(
-    conversation_storage: Arc<dyn ConversationStorage>,
-    conversation_item_storage: Arc<dyn ConversationItemStorage>,
-    response_storage: Arc<dyn ResponseStorage>,
+    persistence: &PersistenceHandles,
+    memory_execution_context: MemoryExecutionContext,
+    conversation_turn_info: Option<ConversationTurnInfo>,
     response: &ResponsesResponse,
     original_request: &ResponsesRequest,
     request_context: Option<StorageRequestContext>,
@@ -164,12 +164,15 @@ pub(crate) async fn persist_response_if_needed(

     if let Ok(response_json) = to_value(response) {
         if let Err(e) = persist_conversation_items(
-            conversation_storage,
-            conversation_item_storage,
-            response_storage,
+            persistence.conversation_storage.clone(),
+            persistence.conversation_item_storage.clone(),
+            persistence.conversation_memory_writer.clone(),
+            persistence.response_storage.clone(),
             &response_json,
             original_request,
             request_context,
+            memory_execution_context,
+            conversation_turn_info,
         )
         .await
         {
diff --git a/model_gateway/src/routers/grpc/harmony/responses/common.rs b/model_gateway/src/routers/grpc/harmony/responses/common.rs
index 24770772c..c87f24568 100644
--- a/model_gateway/src/routers/grpc/harmony/responses/common.rs
+++ b/model_gateway/src/routers/grpc/harmony/responses/common.rs
@@ -18,7 +18,11 @@ use tracing::{debug, error, warn};
 use uuid::Uuid;

 use super::execution::ToolResult;
-use crate::routers::{error, grpc::common::responses::ResponsesContext};
+use crate::routers::{
+    common::persistence_utils::{count_conversation_turn_info, ConversationTurnInfo},
+    error,
+    grpc::common::responses::ResponsesContext,
+};

 /// Record of a single MCP tool call execution
 ///
@@ -40,6 +44,12 @@ pub(super) struct McpCallTracking {
     pub tool_calls: Vec,
 }

+/// Loaded request bundle for Harmony Responses path.
+pub(super) struct LoadedRequest {
+    pub request: ResponsesRequest,
+    pub turn_info: Option<ConversationTurnInfo>,
+}
+
 impl McpCallTracking {
     pub fn new() -> Self {
         Self {
@@ -191,16 +201,25 @@ pub(super) fn inject_mcp_metadata(
 pub(super) async fn load_previous_messages(
     ctx: &ResponsesContext,
     request: ResponsesRequest,
-) -> Result<ResponsesRequest, Response> {
+    stm_enabled: bool,
+) -> Result<LoadedRequest, Response> {
     let Some(ref prev_id_str) = request.previous_response_id else {
-        // No previous_response_id, return request as-is
-        return Ok(request);
+        // No previous_response_id: return request as-is.
+        // If a conversation ID is present we have not loaded its history here,
+        // so turn counts from request.input alone would be wrong — skip STMO.
+        let turn_info = if stm_enabled && request.conversation.is_none() {
+            Some(count_conversation_turn_info(&request.input))
+        } else {
+            None
+        };
+        return Ok(LoadedRequest { request, turn_info });
     };

     let prev_id = ResponseId::from(prev_id_str.as_str());

     // Load response chain from storage
     let chain = match ctx
+        .persistence
         .response_storage
         .get_response_chain(&prev_id, None)
         .await
@@ -298,7 +317,16 @@ pub(super) async fn load_previous_messages(
     modified_request.input = ResponseInput::Items(all_items);
     modified_request.previous_response_id = None;

-    Ok(modified_request)
+    let turn_info = if stm_enabled {
+        Some(count_conversation_turn_info(&modified_request.input))
+    } else {
+        None
+    };
+
+    Ok(LoadedRequest {
+        request: modified_request,
+        turn_info,
+    })
 }

 /// Strip `ResponseTool::ImageGeneration` from a request's tools list once
diff --git a/model_gateway/src/routers/grpc/harmony/responses/non_streaming.rs b/model_gateway/src/routers/grpc/harmony/responses/non_streaming.rs
index 3fd641141..727114033 100644
--- a/model_gateway/src/routers/grpc/harmony/responses/non_streaming.rs
+++ b/model_gateway/src/routers/grpc/harmony/responses/non_streaming.rs
@@ -60,7 +60,14 @@ pub(crate) async fn serve_harmony_responses(
     let original_request = request.clone();

     // Load previous conversation history if previous_response_id is set
-    let current_request = load_previous_messages(ctx, request).await?;
+    let loaded_request = load_previous_messages(
+        ctx,
+        request,
+        ctx.memory_execution_context.stm_enabled.active(),
+    )
+    .await?;
+    let current_request = loaded_request.request;
+    let conversation_turn_info = loaded_request.turn_info;

     // Check MCP connection and get whether MCP tools are present
     let (has_mcp_tools, mcp_servers) =
@@ -81,9 +88,9 @@ pub(crate) async fn serve_harmony_responses(

     // Persist response to storage if store=true
     persist_response_if_needed(
-        ctx.conversation_storage.clone(),
-        ctx.conversation_item_storage.clone(),
-        ctx.response_storage.clone(),
+        &ctx.persistence,
+        ctx.memory_execution_context.clone(),
+        conversation_turn_info,
         &response,
         &original_request,
         ctx.request_context.clone(),
diff --git a/model_gateway/src/routers/grpc/harmony/responses/streaming.rs b/model_gateway/src/routers/grpc/harmony/responses/streaming.rs
index 496a348af..0a7133f11 100644
--- a/model_gateway/src/routers/grpc/harmony/responses/streaming.rs
+++ b/model_gateway/src/routers/grpc/harmony/responses/streaming.rs
@@ -22,7 +22,7 @@ use crate::{
     middleware::TenantRequestMeta,
     observability::metrics::Metrics,
     routers::{
-        common::mcp_utils::DEFAULT_MAX_ITERATIONS,
+        common::{mcp_utils::DEFAULT_MAX_ITERATIONS, persistence_utils::ConversationTurnInfo},
         grpc::{
             common::responses::{
                 build_sse_response, ensure_mcp_connection, persist_response_if_needed,
@@ -47,10 +47,18 @@ pub(crate) async fn serve_harmony_responses_stream(
     tenant_request_meta: TenantRequestMeta,
 ) -> Response {
     // Load previous conversation history if previous_response_id is set
-    let current_request = match load_previous_messages(ctx, request.clone()).await {
+    let loaded_request = match load_previous_messages(
+        ctx,
+        request.clone(),
+        ctx.memory_execution_context.stm_enabled.active(),
+    )
+    .await
+    {
         Ok(req) => req,
         Err(err_response) => return err_response,
     };
+    let current_request = loaded_request.request;
+    let conversation_turn_info = loaded_request.turn_info;

     // Check MCP connection BEFORE starting stream and get whether MCP tools are present
     let (has_mcp_tools, mcp_servers) = match ensure_mcp_connection(
@@ -102,6 +110,7 @@ pub(crate) async fn serve_harmony_responses_stream(
             &request,
             tenant_request_meta.clone(),
             mcp_servers,
+            conversation_turn_info,
             &mut emitter,
             &tx,
         )
@@ -112,6 +121,7 @@ pub(crate) async fn serve_harmony_responses_stream(
             &current_request,
             &request,
             tenant_request_meta,
+            conversation_turn_info,
             &mut emitter,
             &tx,
         )
@@ -131,12 +141,17 @@
 /// - Loops through tool execution iterations
 /// - Emits final response.completed event
 /// - Persists response internally
+#[expect(
+    clippy::too_many_arguments,
+    reason = "streaming MCP loop threads ctx, requests, MCP state, turn info, emitter, and tx together"
+)]
 async fn execute_mcp_tool_loop_streaming(
     ctx: &ResponsesContext,
     mut current_request: ResponsesRequest,
     original_request: &ResponsesRequest,
     tenant_request_meta: TenantRequestMeta,
     mcp_servers: Vec<McpServerBinding>,
+    conversation_turn_info: Option<ConversationTurnInfo>,
     emitter: &mut ResponseStreamEventEmitter,
     tx: &mpsc::UnboundedSender>,
 ) {
@@ -389,9 +404,9 @@ async fn execute_mcp_tool_loop_streaming(

     // Persist response to storage if store=true
     persist_response_if_needed(
-        ctx.conversation_storage.clone(),
-        ctx.conversation_item_storage.clone(),
-        ctx.response_storage.clone(),
+        &ctx.persistence,
+        ctx.memory_execution_context.clone(),
+        conversation_turn_info,
         &final_response,
         original_request,
         ctx.request_context.clone(),
@@ -427,6 +442,7 @@ async fn execute_without_mcp_streaming(
     current_request: &ResponsesRequest,
     original_request: &ResponsesRequest,
     tenant_request_meta: TenantRequestMeta,
+    conversation_turn_info: Option<ConversationTurnInfo>,
     emitter: &mut ResponseStreamEventEmitter,
     tx: &mpsc::UnboundedSender>,
 ) {
@@ -477,9 +493,9 @@ async fn execute_without_mcp_streaming(

     // Persist response to storage if store=true
     persist_response_if_needed(
-        ctx.conversation_storage.clone(),
-        ctx.conversation_item_storage.clone(),
-        ctx.response_storage.clone(),
+        &ctx.persistence,
+        ctx.memory_execution_context.clone(),
+        conversation_turn_info,
         &final_response,
         original_request,
         ctx.request_context.clone(),
diff --git a/model_gateway/src/routers/grpc/regular/responses/common.rs b/model_gateway/src/routers/grpc/regular/responses/common.rs
index a152ba934..15a05876c 100644
--- a/model_gateway/src/routers/grpc/regular/responses/common.rs
+++ b/model_gateway/src/routers/grpc/regular/responses/common.rs
@@ -24,7 +24,10 @@ use tracing::{debug, warn};

 use crate::{
     middleware::TenantRequestMeta,
     routers::{
-        common::persistence_utils::split_stored_message_content, error,
+        common::persistence_utils::{
+            count_conversation_turn_info, split_stored_message_content, ConversationTurnInfo,
+        },
+        error,
         grpc::common::responses::ResponsesContext,
     },
 };
@@ -51,6 +54,12 @@ pub(super) struct ResponsesCallContext {
     pub tenant_request_meta: TenantRequestMeta,
 }

+/// Loaded request bundle for Regular Responses path.
+pub(super) struct LoadedRequest {
+    pub request: ResponsesRequest,
+    pub turn_info: Option<ConversationTurnInfo>,
+}
+
 impl ToolLoopState {
     pub fn new(original_input: ResponseInput) -> Self {
         Self {
@@ -165,14 +174,21 @@ pub(super) fn convert_mcp_tools_to_chat_tools(session: &McpToolSession<'_>) -> V
 pub(super) async fn load_conversation_history(
     ctx: &ResponsesContext,
     request: &ResponsesRequest,
-) -> Result<ResponsesRequest, Response> {
+    stm_enabled: bool,
+) -> Result<LoadedRequest, Response> {
     let mut modified_request = request.clone();
     let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
+    // Tracks the raw DB item count (all types) for the conversation path so
+    // that total_items in ConversationTurnInfo is not undercounted when
+    // function_call/function_call_output/MCP items are present in storage
+    // but filtered out of the inference window.
+    let mut raw_stored_item_count: Option<usize> = None;

     // Handle previous_response_id by loading response chain
     if let Some(ref prev_id_str) = modified_request.previous_response_id {
         let prev_id = ResponseId::from(prev_id_str.as_str());
         match ctx
+            .persistence
             .response_storage
             .get_response_chain(&prev_id, None)
             .await
@@ -241,6 +257,7 @@ pub(super) async fn load_conversation_history(

         // Check if conversation exists - return error if not found
         let conversation = ctx
+            .persistence
             .conversation_storage
             .get_conversation(&conv_id)
             .await
@@ -260,20 +277,39 @@ pub(super) async fn load_conversation_history(
             ));
         }

-        // Load conversation history
-        const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;
+        // Load conversation history.
+        // Fetch up to cap items. If we get cap rows back the conversation is at
+        // or beyond the configured limit; reject when STMO is active and the
+        // response will be persisted. The SQL LIMIT also bounds the inference
+        // window, so no post-fetch truncation is needed.
+        let cap = ctx.max_conversation_history_items;
         let params = data_connector::ListParams {
-            limit: MAX_CONVERSATION_HISTORY_ITEMS,
+            limit: cap,
             order: data_connector::SortOrder::Asc,
             after: None,
         };
+
         match ctx
+            .persistence
             .conversation_item_storage
             .list_items(&conv_id, params)
             .await
         {
             Ok(stored_items) => {
+                // Only reject oversized conversations when the response will
+                // actually be persisted. store=false requests skip persistence
+                // entirely, so STMO is never enqueued and the cap does not apply.
+                if stm_enabled && request.store.unwrap_or(true) && stored_items.len() >= cap {
+                    return Err(error::bad_request(
+                        "conversation_too_large",
+                        format!(
+                            "Conversation has reached the configured limit of {cap} history items. \
+                             Increase max_conversation_history_items in the router config \
+                             or reduce conversation length before using short-term memory \
+                             optimization.",
+                        ),
+                    ));
+                }
+                raw_stored_item_count = Some(stored_items.len());
                 let mut items: Vec<ResponseInputOutputItem> = Vec::new();
                 for item in stored_items {
                     if item.item_type == "message" {
@@ -358,7 +394,46 @@ pub(super) async fn load_conversation_history(
         "Loaded conversation history"
     );

-    Ok(modified_request)
+    let turn_info = if stm_enabled {
+        // If a conversation was requested but list_items failed, the assembled
+        // input only contains the current request — STMO turn counts would be
+        // wrong. Skip STMO for this request so persistence does not enqueue a
+        // job with an undercounted last_index/target_item_end.
+        if request.conversation.is_some() && raw_stored_item_count.is_none() {
+            None
+        } else {
+            let mut info = count_conversation_turn_info(&modified_request.input);
+            // If we loaded from conversation storage, total_items from the
+            // assembled (message-only) input undercounts — function_call,
+            // function_call_output, and MCP items are in the DB but filtered
+            // out of the inference window. Use the raw DB count instead so
+            // target_item_end points at the correct absolute position.
+            if let Some(raw_count) = raw_stored_item_count {
+                // Only apply the raw-count correction when no response chain
+                // was also loaded. If previous_response_id was set, the chain
+                // merge ran last and overwrote modified_request.input with the
+                // full replayed history — count_conversation_turn_info already
+                // saw every item, so no correction is needed. (conversation and
+                // previous_response_id are mutually exclusive in the API, but
+                // we guard here for safety.)
+                if request.previous_response_id.is_none() {
+                    let current_input_count = match &request.input {
+                        ResponseInput::Text(_) => 1,
+                        ResponseInput::Items(items) => items.len(),
+                    };
+                    info.total_items = raw_count + current_input_count;
+                }
+            }
+            Some(info)
+        }
+    } else {
+        None
+    };
+
+    Ok(LoadedRequest {
+        request: modified_request,
+        turn_info,
+    })
 }

 /// Build next request with updated conversation history
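The raw-count correction above is easiest to sanity-check with concrete numbers (illustrative values; the arithmetic is the same one the conversation branch performs):

    fn main() {
        // The DB holds 9 rows for this conversation: 7 messages plus a
        // function_call / function_call_output pair that the inference window
        // filters out. Counting the assembled (message-only) input would give
        // total_items = 7 and aim target_item_end at the wrong absolute slot.
        let raw_stored_item_count: usize = 9;
        let current_input_count: usize = 1; // the new user message in this request

        let corrected_total = raw_stored_item_count + current_input_count;
        assert_eq!(corrected_total, 10); // matches the item's true DB position
    }

The previous_response_id guard exists because the chain merge rewrites the whole input with replayed history: in that path count_conversation_turn_info already sees every item, and applying raw_count again would double-count.
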
diff --git a/model_gateway/src/routers/grpc/regular/responses/handlers.rs b/model_gateway/src/routers/grpc/regular/responses/handlers.rs
index bcad8eeb7..070a4f745 100644
--- a/model_gateway/src/routers/grpc/regular/responses/handlers.rs
+++ b/model_gateway/src/routers/grpc/regular/responses/handlers.rs
@@ -111,19 +111,28 @@ async fn route_responses_streaming(
     request: Arc<ResponsesRequest>,
     params: ResponsesCallContext,
 ) -> Response {
-    // 1. Load conversation history
-    let modified_request = match load_conversation_history(ctx, &request).await {
-        Ok(req) => req,
-        Err(response) => return response, // Already a Response with proper status code
-    };
-
-    // 2. Check MCP connection and get whether MCP tools are present
+    // 1. Check MCP connection first so we can gate STMO logic correctly.
+    //    ensure_mcp_connection only inspects request.tools, not request.input,
+    //    so it is safe to call before history loading.
     let (has_mcp_tools, mcp_servers) =
         match ensure_mcp_connection(&ctx.mcp_orchestrator, request.tools.as_deref()).await {
             Ok(result) => result,
             Err(response) => return response,
         };

+    // 2. Load conversation history.
+    //    The MCP streaming path (execute_tool_loop_streaming) never calls
+    //    persist_response_if_needed, so STMO is never enqueued there. Disable
+    //    stm_enabled for that path so the cap overflow check does not reject
+    //    requests that will never trigger STMO anyway.
+    let stm_enabled = ctx.memory_execution_context.stm_enabled.active() && !has_mcp_tools;
+    let loaded_request = match load_conversation_history(ctx, &request, stm_enabled).await {
+        Ok(req) => req,
+        Err(response) => return response,
+    };
+    let modified_request = loaded_request.request;
+    let conversation_turn_info = loaded_request.turn_info;
+
     if has_mcp_tools {
         debug!("MCP tools detected in streaming mode, using streaming tool loop");

@@ -148,5 +157,12 @@ async fn route_responses_streaming(
     };

     // 4. Execute chat pipeline and convert streaming format (no MCP tools)
-    streaming::convert_chat_stream_to_responses_stream(ctx, chat_request, params, &request).await
+    streaming::convert_chat_stream_to_responses_stream(
+        ctx,
+        chat_request,
+        params,
+        &request,
+        conversation_turn_info,
+    )
+    .await
 }
diff --git a/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs b/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs
index 483092800..d531215d1 100644
--- a/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs
+++ b/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs
@@ -46,7 +46,14 @@ pub(super) async fn route_responses_internal(
     params: ResponsesCallContext,
 ) -> Result {
     // 1. Load conversation history and build modified request
-    let modified_request = load_conversation_history(ctx, &request).await?;
+    let loaded_request = load_conversation_history(
+        ctx,
+        &request,
+        ctx.memory_execution_context.stm_enabled.active(),
+    )
+    .await?;
+    let modified_request = loaded_request.request;
+    let conversation_turn_info = loaded_request.turn_info;

     // 2. Check MCP connection and get whether MCP tools are present
     let (has_mcp_tools, mcp_servers) =
@@ -64,9 +71,9 @@ pub(super) async fn route_responses_internal(

     // 5. Persist response to storage if store=true
     persist_response_if_needed(
-        ctx.conversation_storage.clone(),
-        ctx.conversation_item_storage.clone(),
-        ctx.response_storage.clone(),
+        &ctx.persistence,
+        ctx.memory_execution_context.clone(),
+        conversation_turn_info,
         &responses_response,
         &request,
         ctx.request_context.clone(),
diff --git a/model_gateway/src/routers/grpc/regular/responses/streaming.rs b/model_gateway/src/routers/grpc/regular/responses/streaming.rs
index d269808d5..20c709ec8 100644
--- a/model_gateway/src/routers/grpc/regular/responses/streaming.rs
+++ b/model_gateway/src/routers/grpc/regular/responses/streaming.rs
@@ -30,10 +30,7 @@ use openai_protocol::{
     },
 };
 use serde_json::{json, Value};
-use smg_data_connector::{
-    ConversationItemStorage, ConversationStorage, RequestContext as StorageRequestContext,
-    ResponseStorage,
-};
+use smg_data_connector::RequestContext as StorageRequestContext;
 use smg_mcp::{McpServerBinding, McpToolSession, ResponseFormat, ToolExecutionInput};
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -48,14 +45,18 @@ use super::{
     conversions,
 };
 use crate::{
+    memory::MemoryExecutionContext,
     observability::metrics::{metrics_labels, Metrics},
     routers::{
-        common::mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS},
+        common::{
+            mcp_utils::{prepare_hosted_dispatch_args, DEFAULT_MAX_ITERATIONS},
+            persistence_utils::ConversationTurnInfo,
+        },
         grpc::{
             common::responses::{
                 build_sse_response, persist_response_if_needed,
                 streaming::{attach_mcp_server_label, OutputItemType, ResponseStreamEventEmitter},
-                ResponsesContext,
+                PersistenceHandles, ResponsesContext,
             },
             utils,
         },
@@ -79,6 +80,7 @@ pub(super) async fn convert_chat_stream_to_responses_stream(
     chat_request: Arc<ChatCompletionRequest>,
     params: ResponsesCallContext,
     original_request: &ResponsesRequest,
+    conversation_turn_info: Option<ConversationTurnInfo>,
 ) -> Response {
     debug!("Converting chat SSE stream to responses SSE format");

@@ -102,10 +104,9 @@ pub(super) async fn convert_chat_stream_to_responses_stream(
     // Spawn background task to transform stream
     let original_request_clone = original_request.clone();
-    let response_storage = ctx.response_storage.clone();
-    let conversation_storage = ctx.conversation_storage.clone();
-    let conversation_item_storage = ctx.conversation_item_storage.clone();
+    let persistence = ctx.persistence.clone();
     let request_context = ctx.request_context.clone();
+    let memory_execution_context = ctx.memory_execution_context.clone();

     #[expect(
         clippy::disallowed_methods,
@@ -115,10 +116,10 @@ pub(super) async fn convert_chat_stream_to_responses_stream(
         if let Err(e) = process_and_transform_sse_stream(
             body,
             original_request_clone,
-            response_storage,
-            conversation_storage,
-            conversation_item_storage,
+            persistence,
             request_context,
+            memory_execution_context,
+            conversation_turn_info,
             tx.clone(),
         )
         .await
@@ -139,10 +140,10 @@ pub(super) async fn convert_chat_stream_to_responses_stream(
 async fn process_and_transform_sse_stream(
     body: Body,
     original_request: ResponsesRequest,
-    response_storage: Arc<dyn ResponseStorage>,
-    conversation_storage: Arc<dyn ConversationStorage>,
-    conversation_item_storage: Arc<dyn ConversationItemStorage>,
+    persistence: PersistenceHandles,
     request_context: Option<StorageRequestContext>,
+    memory_execution_context: MemoryExecutionContext,
+    conversation_turn_info: Option<ConversationTurnInfo>,
     tx: mpsc::UnboundedSender>,
 ) -> Result<(), String> {
     // Create accumulator for final response
@@ -231,9 +232,9 @@ async fn process_and_transform_sse_stream(
     // Finalize and persist accumulated response
     let final_response = accumulator.finalize();
     persist_response_if_needed(
-        conversation_storage,
-        conversation_item_storage,
-        response_storage,
+        &persistence,
+        memory_execution_context,
+        conversation_turn_info,
         &final_response,
         &original_request,
         request_context,
diff --git a/model_gateway/src/routers/grpc/router.rs b/model_gateway/src/routers/grpc/router.rs
index 524da0773..153e20e2f 100644
--- a/model_gateway/src/routers/grpc/router.rs
+++ b/model_gateway/src/routers/grpc/router.rs
@@ -14,7 +14,8 @@ use tracing::debug;

 use super::{
     common::responses::{
-        handlers::cancel_response_impl, utils::validate_worker_availability, ResponsesContext,
+        handlers::cancel_response_impl, utils::validate_worker_availability, PersistenceHandles,
+        ResponsesContext,
     },
     context::SharedComponents,
     harmony::{serve_harmony_responses, serve_harmony_responses_stream, HarmonyDetector},
@@ -24,7 +25,8 @@ use super::{
 };
 use crate::{
     app_context::AppContext,
-    config::types::RetryConfig,
+    config::types::{MemoryRuntimeConfig, RetryConfig},
+    memory::MemoryExecutionContext,
     middleware::TenantRequestMeta,
     observability::metrics::{metrics_labels, Metrics},
     routers::{
@@ -47,6 +49,8 @@ pub struct GrpcRouter {
     shared_components: Arc<SharedComponents>,
     responses_context: ResponsesContext,
     harmony_responses_context: ResponsesContext,
+    max_conversation_history_items: usize,
+    memory_runtime_config: MemoryRuntimeConfig,
     retry_config: RetryConfig,
 }
@@ -139,18 +143,27 @@ impl GrpcRouter {
         // Capture storage request context from middleware task-local (before any spawn)
         let storage_request_context = smg_data_connector::current_request_context();
+        let max_conversation_history_items = ctx.router_config.max_conversation_history_items;
+        let memory_runtime_config = ctx.router_config.memory_runtime.clone();

-        // Helper closure to create responses context with a given pipeline
+        // Helper closure to create responses context with a given pipeline.
+        // Uses MemoryExecutionContext::default() because these contexts are built at
+        // startup before any request headers are available. Per-request contexts are
+        // rebuilt inside route_responses_impl with the actual request headers.
         let create_responses_context = |pipeline: &RequestPipeline| {
             ResponsesContext::new(
                 Arc::new(pipeline.clone()),
                 shared_components.clone(),
-                ctx.response_storage.clone(),
-                ctx.conversation_storage.clone(),
-                ctx.conversation_item_storage.clone(),
-                ctx.conversation_memory_writer.clone(),
+                PersistenceHandles {
+                    response_storage: ctx.response_storage.clone(),
+                    conversation_storage: ctx.conversation_storage.clone(),
+                    conversation_item_storage: ctx.conversation_item_storage.clone(),
+                    conversation_memory_writer: ctx.conversation_memory_writer.clone(),
+                },
                 mcp_orchestrator.clone(),
                 storage_request_context.clone(),
+                max_conversation_history_items,
+                MemoryExecutionContext::default(),
             )
         };

@@ -169,6 +182,8 @@ impl GrpcRouter {
             shared_components,
             responses_context,
             harmony_responses_context,
+            max_conversation_history_items,
+            memory_runtime_config,
             retry_config: ctx.router_config.effective_retry_config(),
         })
     }
@@ -325,6 +340,15 @@ impl GrpcRouter {
         let is_harmony =
             HarmonyDetector::is_harmony_model_in_registry(&self.worker_registry, &body.model);

+        // Build memory execution context from per-request headers. gRPC clients map
+        // metadata to HTTP/2 headers, so headers may carry memory config just like
+        // the HTTP path. An absent or empty header produces the default (no-op) context.
+        let empty_headers = HeaderMap::new();
+        let memory_execution_context = MemoryExecutionContext::from_http_headers(
+            headers.unwrap_or(&empty_headers),
+            &self.memory_runtime_config,
+        );
+
         if is_harmony {
             debug!(
                 "Processing Harmony responses request for model: {}, streaming: {}",
@@ -334,16 +358,11 @@ impl GrpcRouter {
             let harmony_ctx = ResponsesContext::new(
                 Arc::new(self.harmony_pipeline.clone()),
                 self.shared_components.clone(),
-                self.harmony_responses_context.response_storage.clone(),
-                self.harmony_responses_context.conversation_storage.clone(),
-                self.harmony_responses_context
-                    .conversation_item_storage
-                    .clone(),
-                self.harmony_responses_context
-                    .conversation_memory_writer
-                    .clone(),
+                self.harmony_responses_context.persistence.clone(),
                 self.harmony_responses_context.mcp_orchestrator.clone(),
                 smg_data_connector::current_request_context(),
+                self.max_conversation_history_items,
+                memory_execution_context,
             );

             if body.stream.unwrap_or(false) {
@@ -357,8 +376,20 @@ impl GrpcRouter {
                 }
             }
         } else {
+            // Build a fresh per-request context so that storage_request_context and
+            // memory_execution_context reflect the current request's headers rather
+            // than the stale values captured at startup.
+            let regular_ctx = ResponsesContext::new(
+                Arc::new(self.pipeline.clone()),
+                self.shared_components.clone(),
+                self.responses_context.persistence.clone(),
+                self.responses_context.mcp_orchestrator.clone(),
+                smg_data_connector::current_request_context(),
+                self.max_conversation_history_items,
+                memory_execution_context,
+            );
             responses::route_responses(
-                &self.responses_context,
+                &regular_ctx,
                 Arc::new(body.clone()),
                 headers.cloned(),
                 tenant_meta.clone(),
diff --git a/model_gateway/src/routers/openai/context.rs b/model_gateway/src/routers/openai/context.rs
index 9e94ccd24..666161be5 100644
--- a/model_gateway/src/routers/openai/context.rs
+++ b/model_gateway/src/routers/openai/context.rs
@@ -14,7 +14,8 @@ use smg_mcp::{McpOrchestrator, McpToolSession};
 use super::provider::Provider;
 use crate::{
     config::RouterConfig, memory::MemoryExecutionContext, middleware,
-    middleware::TenantRequestMeta, worker::Worker,
+    middleware::TenantRequestMeta, routers::common::persistence_utils::ConversationTurnInfo,
+    worker::Worker,
 };

 pub struct RequestContext {
@@ -132,6 +133,7 @@ pub struct PayloadState {
 pub struct ResponsesPayloadState {
     pub previous_response_id: Option<String>,
     pub existing_mcp_list_tools_labels: Vec<String>,
+    pub conversation_turn_info: Option<ConversationTurnInfo>,
 }

 impl RequestContext {
@@ -268,6 +270,7 @@ pub struct OwnedStreamingContext {
     pub original_body: ResponsesRequest,
     pub previous_response_id: Option<String>,
     pub existing_mcp_list_tools_labels: Vec<String>,
+    pub conversation_turn_info: Option<ConversationTurnInfo>,
     pub storage: StorageHandles,
 }

@@ -306,6 +309,7 @@ impl RequestContext {
             original_body,
             previous_response_id: responses_payload_state.previous_response_id,
             existing_mcp_list_tools_labels: responses_payload_state.existing_mcp_list_tools_labels,
+            conversation_turn_info: responses_payload_state.conversation_turn_info,
             storage: StorageHandles {
                 response,
                 conversation,
diff --git a/model_gateway/src/routers/openai/responses/history.rs b/model_gateway/src/routers/openai/responses/history.rs
index b457e1c2c..b9a3051e8 100644
--- a/model_gateway/src/routers/openai/responses/history.rs
+++ b/model_gateway/src/routers/openai/responses/history.rs
@@ -19,17 +19,19 @@ use crate::{
     observability::metrics::{metrics_labels, Metrics},
     routers::{
         common::{
-            header_utils::ConversationMemoryConfig, persistence_utils::split_stored_message_content,
+            header_utils::ConversationMemoryConfig,
+            persistence_utils::{
+                count_conversation_turn_info, split_stored_message_content, ConversationTurnInfo,
+            },
         },
         error,
     },
 };

-const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100;
-
 pub(crate) struct LoadedInputHistory {
     pub previous_response_id: Option<String>,
     pub existing_mcp_list_tools_labels: Vec<String>,
+    pub conversation_turn_info: Option<ConversationTurnInfo>,
 }

 /// Load conversation history and/or previous response chain into request input.
@@ -41,6 +43,7 @@ pub(crate) async fn load_input_history(
     conversation: Option<&str>,
     request_body: &mut ResponsesRequest,
     model: &str,
+    stm_enabled: bool,
 ) -> Result<LoadedInputHistory, Response> {
     let previous_response_id = request_body
         .previous_response_id
@@ -50,6 +53,12 @@ pub(crate) async fn load_input_history(

     // Load items from previous response chain if specified
     let mut chain_items: Option<Vec<ResponseInputOutputItem>> = None;
+    let mut raw_stored_item_count: Option<usize> = None;
+    // Capture current request input count before history loading mutates request_body.input
+    let current_input_count = match &request_body.input {
+        ResponseInput::Text(_) => 1,
+        ResponseInput::Items(items) => items.len(),
+    };
     if let Some(prev_id_str) = &previous_response_id {
         let prev_id = ResponseId::from(prev_id_str.as_str());
         match components
@@ -138,18 +147,48 @@ pub(crate) async fn load_input_history(
         ));
     }

+    // Fetch up to cap items. If we get cap rows back the conversation is at
+    // or beyond the configured limit; reject when STMO is active and the
+    // response will be persisted. The SQL LIMIT also bounds the inference
+    // window, so no post-fetch truncation is needed.
+    let cap = components
+        .shared
+        .router_config
+        .max_conversation_history_items;
     let params = ListParams {
-        limit: MAX_CONVERSATION_HISTORY_ITEMS,
+        limit: cap,
         order: SortOrder::Asc,
         after: None,
     };
-
     match components
         .conversation_item_storage
         .list_items(&conv_id, params)
         .await
     {
         Ok(stored_items) => {
+            // Only reject oversized conversations when the response will
+            // actually be persisted. store=false requests skip persistence
+            // entirely, so STMO is never enqueued and the cap does not apply.
+            if stm_enabled && request_body.store.unwrap_or(true) && stored_items.len() >= cap {
+                Metrics::record_router_error(
+                    metrics_labels::ROUTER_OPENAI,
+                    metrics_labels::BACKEND_EXTERNAL,
+                    metrics_labels::CONNECTION_HTTP,
+                    model,
+                    metrics_labels::ENDPOINT_RESPONSES,
+                    metrics_labels::ERROR_VALIDATION,
+                );
+                return Err(error::bad_request(
+                    "conversation_too_large",
+                    format!(
+                        "Conversation has reached the configured limit of {cap} history items. \
+                         Increase max_conversation_history_items in the router config \
+                         or reduce conversation length before using short-term memory \
+                         optimization.",
+                    ),
+                ));
+            }
+            raw_stored_item_count = Some(stored_items.len());
             let mut items: Vec<ResponseInputOutputItem> = Vec::new();
             for item in stored_items {
                 match item.item_type.as_str() {
@@ -232,9 +271,41 @@ pub(crate) async fn load_input_history(
         request_body.input = ResponseInput::Items(items);
     }

+    let conversation_turn_info = if stm_enabled {
+        // If a conversation was requested but list_items failed, the assembled
+        // input only contains the current request — STMO turn counts would be
+        // wrong. Skip STMO for this request so persistence does not enqueue a
+        // job with an undercounted last_index/target_item_end.
+        if conversation.is_some() && raw_stored_item_count.is_none() {
+            None
+        } else {
+            let mut info = count_conversation_turn_info(&request_body.input);
+            // Correct total_items when loaded from conversation storage: reasoning
+            // items are skipped during load but are stored in the DB, so the
+            // assembled input underestimates the true conversation size. Use the
+            // raw DB count + current input item count for an accurate target_item_end.
+            if let Some(raw_count) = raw_stored_item_count {
+                // Only apply the raw-count correction when no response chain
+                // was also loaded. If previous_response_id was set, the chain
+                // merge ran last and overwrote request_body.input with the full
+                // replayed history — count_conversation_turn_info already saw
+                // every item, so no correction is needed. (conversation and
+                // previous_response_id are mutually exclusive per the API spec,
+                // but we guard here for safety.)
+                if previous_response_id.is_none() {
+                    info.total_items = raw_count + current_input_count;
+                }
+            }
+            Some(info)
+        }
+    } else {
+        None
+    };
+
     Ok(LoadedInputHistory {
         previous_response_id,
         existing_mcp_list_tools_labels: existing_mcp_list_tools_labels.into_iter().collect(),
+        conversation_turn_info,
     })
 }
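The rejection condition distills to three ANDed clauses; a compact truth-table check (predicate copied from the condition above, with cap = 100 as in the default config):

    fn rejects(stm_enabled: bool, store: Option<bool>, stored_items: usize, cap: usize) -> bool {
        stm_enabled && store.unwrap_or(true) && stored_items >= cap
    }

    fn main() {
        assert!(rejects(true, None, 100, 100));         // store defaults to true: reject
        assert!(!rejects(true, Some(false), 100, 100)); // store=false never persists, so no STMO
        assert!(!rejects(false, None, 100, 100));       // STM gated off: cap not enforced
        assert!(!rejects(true, None, 99, 100));         // under the cap: proceed normally
    }

Because the fetch is LIMITed to cap rows, stored_items.len() == cap is indistinguishable from "at or beyond the limit", which is exactly why the comparison is >= rather than ==.
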
diff --git a/model_gateway/src/routers/openai/responses/non_streaming.rs b/model_gateway/src/routers/openai/responses/non_streaming.rs
index 8a9959c09..7db832cc6 100644
--- a/model_gateway/src/routers/openai/responses/non_streaming.rs
+++ b/model_gateway/src/routers/openai/responses/non_streaming.rs
@@ -41,6 +41,7 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
     let ResponsesPayloadState {
         previous_response_id,
         existing_mcp_list_tools_labels,
+        conversation_turn_info,
     } = ctx.take_responses_payload().unwrap_or_default();

     let original_body = match ctx.responses_request() {
@@ -164,18 +165,22 @@ pub async fn handle_non_streaming_response(mut ctx: RequestContext) -> Response
         previous_response_id.as_deref(),
     );

-    if let (Some(conv_storage), Some(item_storage), Some(resp_storage)) = (
+    if let (Some(conv_storage), Some(item_storage), Some(memory_writer), Some(resp_storage)) = (
         ctx.components.conversation_storage(),
         ctx.components.conversation_item_storage(),
+        ctx.components.conversation_memory_writer(),
         ctx.components.response_storage(),
     ) {
         if let Err(err) = persist_conversation_items(
             conv_storage.clone(),
             item_storage.clone(),
+            memory_writer.clone(),
             resp_storage.clone(),
             &response_json,
             original_body,
             ctx.storage_request_context.clone(),
+            ctx.memory_execution_context.clone(),
+            conversation_turn_info,
         )
         .await
         {
diff --git a/model_gateway/src/routers/openai/responses/route.rs b/model_gateway/src/routers/openai/responses/route.rs
index cf131cae0..e3f145589 100644
--- a/model_gateway/src/routers/openai/responses/route.rs
+++ b/model_gateway/src/routers/openai/responses/route.rs
@@ -115,12 +115,23 @@ pub(in crate::routers::openai) async fn route_responses(
     let mut request_body = body.clone();
     request_body.model = model_id.to_string();
     request_body.conversation = None;
+    let memory_config = extract_conversation_memory_config(headers);
+    let stm_enabled = memory_config
+        .as_ref()
+        .is_some_and(|cfg| cfg.short_term_memory.enabled)
+        && deps
+            .responses_components
+            .shared
+            .router_config
+            .memory_runtime
+            .enabled;

     let loaded_history = match super::history::load_input_history(
         deps.responses_components,
         conversation.map(|c| c.as_id()),
         &mut request_body,
         model,
+        stm_enabled,
     )
     .await
     {
@@ -128,7 +139,7 @@ pub(in crate::routers::openai) async fn route_responses(
         Err(response) => return response,
     };

-    if let Some(memory_config) = extract_conversation_memory_config(headers) {
+    if let Some(memory_config) = memory_config {
         super::history::inject_memory_context(&memory_config, &mut request_body);
     }

@@ -189,6 +200,7 @@ pub(in crate::routers::openai) async fn route_responses(
     ctx.state.responses_payload = Some(ResponsesPayloadState {
         previous_response_id: loaded_history.previous_response_id,
         existing_mcp_list_tools_labels: loaded_history.existing_mcp_list_tools_labels,
+        conversation_turn_info: loaded_history.conversation_turn_info,
     });

     let response = if ctx.is_streaming() {
diff --git a/model_gateway/src/routers/openai/responses/streaming.rs b/model_gateway/src/routers/openai/responses/streaming.rs
index e22e36cf8..3dcdbd2fb 100644
--- a/model_gateway/src/routers/openai/responses/streaming.rs
+++ b/model_gateway/src/routers/openai/responses/streaming.rs
@@ -562,6 +562,7 @@ pub(super) async fn handle_simple_streaming_passthrough(
     let should_store = req.original_body.store.unwrap_or(true);
     let original_request = req.original_body;
     let previous_response_id = req.previous_response_id;
+    let conversation_turn_info = req.conversation_turn_info;
     let storage = req.storage;

     #[expect(
@@ -634,10 +635,13 @@ pub(super) async fn handle_simple_streaming_passthrough(
                 if let Err(err) = persist_conversation_items(
                     storage.conversation.clone(),
                     storage.conversation_item.clone(),
+                    storage.conversation_memory_writer.clone(),
                     storage.response.clone(),
                     &response_json,
                     &original_request,
                     storage.request_context.clone(),
+                    storage.memory_execution_context.clone(),
+                    conversation_turn_info,
                 )
                 .await
                 {
@@ -683,6 +687,7 @@ pub(super) fn handle_streaming_with_tool_interception(
     let original_request = req.original_body;
     let previous_response_id = req.previous_response_id;
     let existing_mcp_list_tools_labels = req.existing_mcp_list_tools_labels;
+    let conversation_turn_info = req.conversation_turn_info;
     let url = req.url;
     let storage = req.storage;

@@ -962,10 +967,13 @@ pub(super) fn handle_streaming_with_tool_interception(
                 if let Err(err) = persist_conversation_items(
                     storage.conversation.clone(),
                     storage.conversation_item.clone(),
+                    storage.conversation_memory_writer.clone(),
                     storage.response.clone(),
                     &response_json,
                     &original_request,
                     storage.request_context.clone(),
+                    storage.memory_execution_context.clone(),
+                    conversation_turn_info,
                 )
                 .await
                 {