diff --git a/src/codex_agent.rs b/src/codex_agent.rs
index 306ffc9..1ad588a 100644
--- a/src/codex_agent.rs
+++ b/src/codex_agent.rs
@@ -124,6 +124,7 @@ impl Agent for CodexAgent {
             CodexAuthMethod::ChatGpt.into(),
             CodexAuthMethod::CodexApiKey.into(),
             CodexAuthMethod::OpenAiApiKey.into(),
+            CodexAuthMethod::CustomModelProvider.into(),
         ];
         // Until codex device code auth works, we can't use this in remote ssh projects
         if std::env::var("NO_BROWSER").is_ok() {
@@ -176,6 +177,11 @@ impl Agent for CodexAgent {
                 codex_login::login_with_api_key(&self.config.codex_home, &api_key)
                     .map_err(Error::into_internal_error)?;
             }
+            CodexAuthMethod::CustomModelProvider => {
+                // For a custom model provider, we assume the user has already set up
+                // their credentials via environment variables or config files as needed.
+                info!("Assuming custom model provider is set up correctly.");
+            }
         }
 
         self.auth_manager.reload();
@@ -280,7 +286,11 @@ impl Agent for CodexAgent {
            self.auth_manager.clone(),
            self.client_capabilities.clone(),
            config.clone(),
-            self.model_presets.clone(),
+            crate::conversation::ConversationModelContext {
+                model_presets: self.model_presets.clone(),
+                model_provider: config.model_provider.clone(),
+                model: config.model.clone(),
+            },
        ));
 
        let load = conversation.load().await?;
@@ -377,6 +387,7 @@ enum CodexAuthMethod {
     ChatGpt,
     CodexApiKey,
     OpenAiApiKey,
+    CustomModelProvider,
 }
 
 impl From<CodexAuthMethod> for AuthMethodId {
@@ -386,6 +397,7 @@ impl From<CodexAuthMethod> for AuthMethodId {
                 CodexAuthMethod::ChatGpt => "chatgpt",
                 CodexAuthMethod::CodexApiKey => "codex-api-key",
                 CodexAuthMethod::OpenAiApiKey => "openai-api-key",
+                CodexAuthMethod::CustomModelProvider => "custom-model-provider",
             }
             .into(),
         )
@@ -420,6 +432,12 @@ impl From<CodexAuthMethod> for AuthMethod {
                )),
                meta: None,
            },
+            CodexAuthMethod::CustomModelProvider => Self {
+                id: method.into(),
+                name: "Use Custom Model Provider".into(),
+                description: Some("Requires setting up a custom model provider.".into()),
+                meta: None,
+            },
        }
    }
 }
@@ -432,6 +450,7 @@ impl TryFrom<AuthMethodId> for CodexAuthMethod {
            "chatgpt" => Ok(CodexAuthMethod::ChatGpt),
            "codex-api-key" => Ok(CodexAuthMethod::CodexApiKey),
            "openai-api-key" => Ok(CodexAuthMethod::OpenAiApiKey),
+            "custom-model-provider" => Ok(CodexAuthMethod::CustomModelProvider),
            _ => Err(Error::invalid_params().with_data("unsupported authentication method")),
        }
    }
diff --git a/src/conversation.rs b/src/conversation.rs
index 96416f8..b18a6c5 100644
--- a/src/conversation.rs
+++ b/src/conversation.rs
@@ -23,7 +23,7 @@ use codex_common::{
     model_presets::ModelPreset,
 };
 use codex_core::{
-    AuthManager, CodexConversation,
+    AuthManager, CodexConversation, ModelProviderInfo,
     config::Config,
     error::CodexErr,
     protocol::{
@@ -124,7 +124,7 @@ impl Conversation {
         auth: Arc<AuthManager>,
         client_capabilities: Arc<RwLock<ClientCapabilities>>,
         config: Config,
-        model_presets: Rc<Vec<ModelPreset>>,
+        model_ctx: ConversationModelContext,
     ) -> Self {
         let (message_tx, message_rx) = mpsc::unbounded_channel();
@@ -133,7 +133,7 @@
             SessionClient::new(session_id, client_capabilities),
             conversation.clone(),
             config,
-            model_presets,
+            model_ctx,
             message_rx,
         );
         let handle = tokio::task::spawn_local(actor.spawn());
@@ -214,6 +214,17 @@ enum SubmissionState {
     Task(TaskState),
 }
 
+/// Model-related context used when creating a conversation.
+#[derive(Clone)]
+pub(crate) struct ConversationModelContext {
+    /// Presets available for models by default.
+    pub(crate) model_presets: Rc<Vec<ModelPreset>>,
+    /// Information about the model provider in use.
+    pub(crate) model_provider: ModelProviderInfo,
+    /// The identifier/name of the currently selected model.
+    pub(crate) model: String,
+}
+
 impl SubmissionState {
     fn is_active(&self) -> bool {
         match self {
@@ -281,6 +292,10 @@ struct PromptState {
     event_count: usize,
     response_tx: Option<oneshot::Sender<Result<PromptResponse, Error>>>,
     submission_id: String,
+    // Whether we've already received streaming deltas for the agent message
+    saw_agent_message_delta: bool,
+    // Whether we've already received streaming deltas for the agent reasoning
+    saw_agent_reasoning_delta: bool,
 }
 
 impl PromptState {
@@ -296,6 +311,8 @@ impl PromptState {
             event_count: 0,
             response_tx: Some(response_tx),
             submission_id,
+            saw_agent_message_delta: false,
+            saw_agent_reasoning_delta: false,
         }
     }
 
@@ -356,17 +373,43 @@ impl PromptState {
             }
             EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
                 // Send this to the client via session/update notification
-                info!("Agent message received: {delta:?}");
+                info!("Agent message delta received: {delta:?}");
+                // Mark that we've seen deltas for the agent message and send the chunk.
+                // When deltas are emitted we will skip the eventual non-streaming
+                // `AgentMessage` to avoid sending duplicated content to the client.
+                self.saw_agent_message_delta = true;
                 client.send_agent_text(delta).await;
             }
+            EventMsg::AgentMessage(AgentMessageEvent { message, .. }) => {
+                // If we already received deltas for this submission, skip the final
+                // full message to avoid double-sending the same content.
+                if self.saw_agent_message_delta {
+                    info!("Skipping final AgentMessage because deltas were streamed");
+                } else {
+                    info!("Agent message (non-delta) received: {message:?}");
+                    client.send_agent_text(message).await;
+                }
+            }
             EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta })
             | EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent {
                 delta,
             }) => {
                 // Send this to the client via session/update notification
-                info!("Agent reasoning message received: {:?}", delta);
+                info!("Agent reasoning message delta received: {:?}", delta);
+                // Mark that we've seen deltas for agent reasoning and send the chunk.
+                // If reasoning deltas are used we will skip the non-streaming
+                // `AgentReasoning` event to avoid duplicating content.
+                self.saw_agent_reasoning_delta = true;
                 client.send_agent_thought(delta).await;
             }
+            EventMsg::AgentReasoning(AgentReasoningEvent { text }) => {
+                if self.saw_agent_reasoning_delta {
+                    info!("Skipping final AgentReasoning because deltas were streamed");
+                } else {
+                    info!("Agent reasoning (non-delta) received: {:?}", text);
+                    client.send_agent_thought(text).await;
+                }
+            }
             EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}) => {
                 // Make sure the section heading actually gets spacing
                 client.send_agent_thought("\n\n").await;
@@ -507,9 +550,7 @@ impl PromptState {
             }
 
             // Since we are getting the deltas, we can ignore these events
-            EventMsg::AgentReasoning(..)
-            | EventMsg::AgentReasoningRawContent(..)
-            | EventMsg::AgentMessage(..)
+            EventMsg::AgentReasoningRawContent(..)
             // In the future we can use this to update usage stats
             | EventMsg::TokenCount(..)
             // we already have a way to diff the turn, so ignore
@@ -1340,8 +1381,8 @@ struct ConversationActor {
     config: Config,
     /// The custom prompts loaded for this workspace.
     custom_prompts: Rc<RefCell<Vec<CustomPrompt>>>,
-    /// The model presets for the conversation.
-    model_presets: Rc<Vec<ModelPreset>>,
+    /// Model-related context (presets, provider, and current model).
+    model_ctx: ConversationModelContext,
     /// A sender for each interested `Op` submission that needs events routed.
     submissions: HashMap<String, SubmissionState>,
     /// A receiver for incoming conversation messages.
@@ -1354,7 +1395,7 @@ impl ConversationActor {
     fn new(
         client: SessionClient,
         conversation: Arc<CodexConversation>,
         config: Config,
-        model_presets: Rc<Vec<ModelPreset>>,
+        model_ctx: ConversationModelContext,
         message_rx: mpsc::UnboundedReceiver<ConversationMessage>,
     ) -> Self {
         Self {
@@ -1363,7 +1404,7 @@ impl ConversationActor {
             client,
             conversation,
             config,
             custom_prompts: Rc::default(),
-            model_presets,
+            model_ctx,
             submissions: HashMap::new(),
             message_rx,
         }
     }
@@ -1546,7 +1587,7 @@
     }
     fn find_model_preset(&self) -> Option<&ModelPreset> {
-        if let Some(preset) = self.model_presets.iter().find(|preset| {
+        if let Some(preset) = self.model_ctx.model_presets.iter().find(|preset| {
             preset.model == self.config.model
                 && preset.effort == self.config.model_reasoning_effort
         }) {
             return Some(preset);
@@ -1554,7 +1595,7 @@
         // If we didn't find it, and it is set to none, see if we can find one with the default value
         if self.config.model_reasoning_effort.is_none()
-            && let Some(preset) = self.model_presets.iter().find(|preset| {
+            && let Some(preset) = self.model_ctx.model_presets.iter().find(|preset| {
                 preset.model == self.config.model
                     && preset.effort == Some(ReasoningEffort::default())
             })
         {
@@ -1569,26 +1610,38 @@
         let current_model_id = self
             .find_model_preset()
             .map(|preset| ModelId(preset.id.into()))
-            .ok_or_else(|| {
-                anyhow::anyhow!("No valid model preset for model {}", self.config.model)
-            })?;
+            // Fall back to the configured model id if no preset is found, to avoid failing session load.
+            .unwrap_or_else(|| ModelId(self.config.model.clone().into()));
 
-        let available_models = self
-            .model_presets
-            .iter()
-            .map(|preset| ModelInfo {
-                model_id: ModelId(preset.id.into()),
-                name: preset.label.into(),
-                description: Some(
-                    preset
-                        .description
-                        .strip_prefix("— ")
-                        .unwrap_or(preset.description)
-                        .into(),
+        let available_models: Vec<ModelInfo> = if self.model_ctx.model_provider.name.is_empty() {
+            self.model_ctx
+                .model_presets
+                .iter()
+                .map(|preset| ModelInfo {
+                    model_id: ModelId(preset.id.into()),
+                    name: preset.label.into(),
+                    description: Some(
+                        preset
+                            .description
+                            .strip_prefix("— ")
+                            .unwrap_or(preset.description)
+                            .into(),
+                    ),
+                    meta: None,
+                })
+                .collect()
+        } else {
+            vec![ModelInfo {
+                model_id: ModelId(self.model_ctx.model.clone().into()),
+                name: format!(
+                    "{}/{}",
+                    self.model_ctx.model_provider.name,
+                    self.model_ctx.model.as_str()
                 ),
+                description: Some("Model configured via the custom model provider".into()),
                 meta: None,
-            })
-            .collect();
+            }]
+        };
 
         Ok(SessionModelState {
             current_model_id,
@@ -1735,6 +1788,7 @@ impl ConversationActor {
     async fn handle_set_model(&mut self, model: ModelId) -> Result<(), Error> {
         let preset = self
+            .model_ctx
             .model_presets
             .iter()
             .find(|p| p.id == model.0.as_ref())
@@ -2422,7 +2476,24 @@ mod tests {
             session_client,
             conversation.clone(),
             config,
-            Default::default(),
+            ConversationModelContext {
+                model_presets: Default::default(),
+                model_provider: ModelProviderInfo {
+                    name: "test".to_string(),
+                    base_url: Some("https://api.test.com".to_string()),
+                    env_key: None,
+                    env_key_instructions: None,
+                    wire_api: Default::default(),
+                    query_params: Default::default(),
+                    http_headers: Default::default(),
+                    env_http_headers: Default::default(),
+                    request_max_retries: Default::default(),
+                    stream_max_retries: Default::default(),
+                    stream_idle_timeout_ms: Default::default(),
+                    requires_openai_auth: false,
+                },
+                model: String::from("gpt-5"),
+            },
             message_rx,
         );
         actor.custom_prompts = Rc::new(RefCell::new(custom_prompts));
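
Note for reviewers: the `saw_agent_message_delta` / `saw_agent_reasoning_delta` flags above implement a small "suppress the final aggregate once deltas have streamed" pattern. The sketch below is a minimal, self-contained illustration of that pattern; the `Event`/`StreamDedup` names are illustrative only, not taken from this codebase:

```rust
/// Tracks whether streaming chunks were seen, so the final aggregate can be skipped.
struct StreamDedup {
    saw_delta: bool,
}

/// A stream either emits incremental chunks, a final full message, or both.
enum Event {
    Delta(String),
    Final(String),
}

impl StreamDedup {
    fn handle(&mut self, event: Event, out: &mut Vec<String>) {
        match event {
            Event::Delta(chunk) => {
                // Remember that chunks were streamed so the final message is skipped.
                self.saw_delta = true;
                out.push(chunk);
            }
            Event::Final(full) => {
                // Only forward the aggregate when no deltas were streamed;
                // otherwise the client would see the same content twice.
                if !self.saw_delta {
                    out.push(full);
                }
            }
        }
    }
}

fn main() {
    let mut dedup = StreamDedup { saw_delta: false };
    let mut out = Vec::new();
    dedup.handle(Event::Delta("Hel".into()), &mut out);
    dedup.handle(Event::Delta("lo".into()), &mut out);
    // The aggregate duplicates the streamed chunks and is dropped.
    dedup.handle(Event::Final("Hello".into()), &mut out);
    assert_eq!(out, vec!["Hel".to_string(), "lo".to_string()]);
}
```

This is also why `PromptState` tracks the message and reasoning streams with separate flags: each stream may arrive delta-only, final-only, or both, so deduplication has to be decided per stream.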