src/codex_agent.rs (21 changes: 20 additions & 1 deletion)
@@ -124,6 +124,7 @@ impl Agent for CodexAgent {
             CodexAuthMethod::ChatGpt.into(),
             CodexAuthMethod::CodexApiKey.into(),
             CodexAuthMethod::OpenAiApiKey.into(),
+            CodexAuthMethod::CustomModelProvider.into(),
         ];
         // Until codex device code auth works, we can't use this in remote ssh projects
         if std::env::var("NO_BROWSER").is_ok() {
@@ -176,6 +177,11 @@ impl Agent for CodexAgent {
                 codex_login::login_with_api_key(&self.config.codex_home, &api_key)
                     .map_err(Error::into_internal_error)?;
             }
+            CodexAuthMethod::CustomModelProvider => {
+                // For a custom model provider, we assume the user has already set up their credentials
+                // via environment variables or config files as needed.
+                info!("Assuming custom model provider is set up correctly.");
+            }
         }

         self.auth_manager.reload();
@@ -280,7 +286,11 @@ impl Agent for CodexAgent {
             self.auth_manager.clone(),
             self.client_capabilities.clone(),
             config.clone(),
-            self.model_presets.clone(),
+            crate::conversation::ConversationModelContext {
+                model_presets: self.model_presets.clone(),
+                model_provider: config.model_provider.clone(),
+                model: config.model.clone(),
+            },
         ));
         let load = conversation.load().await?;

@@ -377,6 +387,7 @@ enum CodexAuthMethod {
     ChatGpt,
     CodexApiKey,
     OpenAiApiKey,
+    CustomModelProvider,
 }

impl From<CodexAuthMethod> for AuthMethodId {
@@ -386,6 +397,7 @@ impl From<CodexAuthMethod> for AuthMethodId {
             CodexAuthMethod::ChatGpt => "chatgpt",
             CodexAuthMethod::CodexApiKey => "codex-api-key",
             CodexAuthMethod::OpenAiApiKey => "openai-api-key",
+            CodexAuthMethod::CustomModelProvider => "custom-model-provider",
         }
         .into(),
         )
@@ -420,6 +432,12 @@ impl From<CodexAuthMethod> for AuthMethod {
                 )),
                 meta: None,
             },
+            CodexAuthMethod::CustomModelProvider => Self {
+                id: method.into(),
+                name: "Use Custom Model Provider".into(),
+                description: Some("Requires setting up a custom model provider.".into()),
+                meta: None,
+            },
         }
     }
 }
@@ -432,6 +450,7 @@ impl TryFrom<AuthMethodId> for CodexAuthMethod {
             "chatgpt" => Ok(CodexAuthMethod::ChatGpt),
             "codex-api-key" => Ok(CodexAuthMethod::CodexApiKey),
             "openai-api-key" => Ok(CodexAuthMethod::OpenAiApiKey),
+            "custom-model-provider" => Ok(CodexAuthMethod::CustomModelProvider),
             _ => Err(Error::invalid_params().with_data("unsupported authentication method")),
         }
     }
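Note: since the `From` and `TryFrom` conversions above are maintained by hand, each variant must survive the round trip through its string id. A minimal standalone sketch of that invariant, using simplified stand-in types rather than the crate's real `AuthMethodId` newtype:

// Standalone sketch (not part of this PR): the id round-trip invariant that the
// hand-written From/TryFrom pair above must uphold. Types are simplified here.
#[derive(Debug, Clone, Copy, PartialEq)]
enum CodexAuthMethod {
    ChatGpt,
    CodexApiKey,
    OpenAiApiKey,
    CustomModelProvider,
}

impl CodexAuthMethod {
    fn id(self) -> &'static str {
        match self {
            Self::ChatGpt => "chatgpt",
            Self::CodexApiKey => "codex-api-key",
            Self::OpenAiApiKey => "openai-api-key",
            Self::CustomModelProvider => "custom-model-provider",
        }
    }

    fn from_id(id: &str) -> Option<Self> {
        match id {
            "chatgpt" => Some(Self::ChatGpt),
            "codex-api-key" => Some(Self::CodexApiKey),
            "openai-api-key" => Some(Self::OpenAiApiKey),
            "custom-model-provider" => Some(Self::CustomModelProvider),
            _ => None,
        }
    }
}

fn main() {
    // Every variant must survive the string round trip.
    for m in [
        CodexAuthMethod::ChatGpt,
        CodexAuthMethod::CodexApiKey,
        CodexAuthMethod::OpenAiApiKey,
        CodexAuthMethod::CustomModelProvider,
    ] {
        assert_eq!(CodexAuthMethod::from_id(m.id()), Some(m));
    }
}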
src/conversation.rs (135 changes: 103 additions & 32 deletions)
@@ -23,7 +23,7 @@ use codex_common::{
     model_presets::ModelPreset,
 };
 use codex_core::{
-    AuthManager, CodexConversation,
+    AuthManager, CodexConversation, ModelProviderInfo,
     config::Config,
     error::CodexErr,
     protocol::{
@@ -124,7 +124,7 @@ impl Conversation {
         auth: Arc<AuthManager>,
         client_capabilities: Arc<Mutex<ClientCapabilities>>,
         config: Config,
-        model_presets: Rc<Vec<ModelPreset>>,
+        model_ctx: ConversationModelContext,
     ) -> Self {
         let (message_tx, message_rx) = mpsc::unbounded_channel();

@@ -133,7 +133,7 @@ impl Conversation {
             SessionClient::new(session_id, client_capabilities),
             conversation.clone(),
             config,
-            model_presets,
+            model_ctx,
             message_rx,
         );
         let handle = tokio::task::spawn_local(actor.spawn());
@@ -214,6 +214,17 @@ enum SubmissionState {
     Task(TaskState),
 }

+/// Model-related context used when creating a conversation.
+#[derive(Clone)]
+pub(crate) struct ConversationModelContext {
+    /// Presets available for models by default.
+    pub(crate) model_presets: Rc<Vec<ModelPreset>>,
+    /// Information about the model provider in use.
+    pub(crate) model_provider: ModelProviderInfo,
+    /// The identifier/name of the currently selected model.
+    pub(crate) model: String,
+}
+
 impl SubmissionState {
     fn is_active(&self) -> bool {
         match self {
@@ -281,6 +292,10 @@ struct PromptState {
     event_count: usize,
     response_tx: Option<oneshot::Sender<Result<StopReason, Error>>>,
     submission_id: String,
+    // Whether we've already received streaming deltas for the agent message
+    saw_agent_message_delta: bool,
+    // Whether we've already received streaming deltas for the agent reasoning
+    saw_agent_reasoning_delta: bool,
 }

 impl PromptState {
@@ -296,6 +311,8 @@ impl PromptState {
             event_count: 0,
             response_tx: Some(response_tx),
             submission_id,
+            saw_agent_message_delta: false,
+            saw_agent_reasoning_delta: false,
         }
     }

@@ -356,17 +373,43 @@ impl PromptState {
             }
             EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
                 // Send this to the client via session/update notification
-                info!("Agent message received: {delta:?}");
+                info!("Agent message delta received: {delta:?}");
+                // Mark that we've seen deltas for the agent message and send the chunk.
+                // When deltas are emitted we will skip the eventual non-streaming
+                // `AgentMessage` to avoid sending duplicated content to the client.
+                self.saw_agent_message_delta = true;
                 client.send_agent_text(delta).await;
             }
+            EventMsg::AgentMessage(AgentMessageEvent { message, .. }) => {
+                // If we already received deltas for this submission, skip the final
+                // full message to avoid double-sending the same content.
+                if self.saw_agent_message_delta {
+                    info!("Skipping final AgentMessage because deltas were streamed");
+                } else {
+                    info!("Agent message (non-delta) received: {message:?}");
+                    client.send_agent_text(message).await;
+                }
+            }
             EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta })
             | EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent {
                 delta,
             }) => {
                 // Send this to the client via session/update notification
-                info!("Agent reasoning message received: {:?}", delta);
+                info!("Agent reasoning message delta received: {:?}", delta);
+                // Mark that we've seen deltas for agent reasoning and send the chunk.
+                // If reasoning deltas are used we will skip the non-streaming
+                // `AgentReasoning` event to avoid duplicating content.
+                self.saw_agent_reasoning_delta = true;
                 client.send_agent_thought(delta).await;
             }
+            EventMsg::AgentReasoning(AgentReasoningEvent { text }) => {
+                if self.saw_agent_reasoning_delta {
+                    info!("Skipping final AgentReasoning because deltas were streamed");
+                } else {
+                    info!("Agent reasoning (non-delta) received: {:?}", text);
+                    client.send_agent_thought(text).await;
+                }
+            }
             EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}) => {
                 // Make sure the section heading actually gets spacing
                 client.send_agent_thought("\n\n").await;
@@ -507,9 +550,7 @@ impl PromptState {
             }

             // Since we are getting the deltas, we can ignore these events
-            EventMsg::AgentReasoning(..)
-            | EventMsg::AgentReasoningRawContent(..)
-            | EventMsg::AgentMessage(..)
+            EventMsg::AgentReasoningRawContent(..)
             // In the future we can use this to update usage stats
             | EventMsg::TokenCount(..)
             // we already have a way to diff the turn, so ignore
@@ -1340,8 +1381,8 @@ struct ConversationActor<A> {
     config: Config,
     /// The custom prompts loaded for this workspace.
     custom_prompts: Rc<RefCell<Vec<CustomPrompt>>>,
-    /// The model presets for the conversation.
-    model_presets: Rc<Vec<ModelPreset>>,
+    /// Model-related context (presets, provider, current model)
+    model_ctx: ConversationModelContext,
     /// A sender for each interested `Op` submission that needs events routed.
     submissions: HashMap<String, SubmissionState>,
     /// A receiver for incoming conversation messages.
Expand All @@ -1354,7 +1395,7 @@ impl<A: Auth> ConversationActor<A> {
client: SessionClient,
conversation: Arc<dyn CodexConversationImpl>,
config: Config,
model_presets: Rc<Vec<ModelPreset>>,
model_ctx: ConversationModelContext,
message_rx: mpsc::UnboundedReceiver<ConversationMessage>,
) -> Self {
Self {
@@ -1363,7 +1404,7 @@ impl<A: Auth> ConversationActor<A> {
             conversation,
             config,
             custom_prompts: Rc::default(),
-            model_presets,
+            model_ctx,
             submissions: HashMap::new(),
             message_rx,
         }
@@ -1546,15 +1587,15 @@ impl<A: Auth> ConversationActor<A> {
     }

     fn find_model_preset(&self) -> Option<&ModelPreset> {
-        if let Some(preset) = self.model_presets.iter().find(|preset| {
+        if let Some(preset) = self.model_ctx.model_presets.iter().find(|preset| {
             preset.model == self.config.model && preset.effort == self.config.model_reasoning_effort
         }) {
             return Some(preset);
         }

         // If we didn't find it, and it is set to none, see if we can find one with the default value
         if self.config.model_reasoning_effort.is_none()
-            && let Some(preset) = self.model_presets.iter().find(|preset| {
+            && let Some(preset) = self.model_ctx.model_presets.iter().find(|preset| {
                 preset.model == self.config.model
                     && preset.effort == Some(ReasoningEffort::default())
             })
@@ -1569,26 +1610,38 @@ impl<A: Auth> ConversationActor<A> {
         let current_model_id = self
             .find_model_preset()
             .map(|preset| ModelId(preset.id.into()))
-            .ok_or_else(|| {
-                anyhow::anyhow!("No valid model preset for model {}", self.config.model)
-            })?;
+            // Fallback to configured model id if no preset found to avoid failing session load.
+            .unwrap_or_else(|| ModelId(self.config.model.clone().into()));

-        let available_models = self
-            .model_presets
-            .iter()
-            .map(|preset| ModelInfo {
-                model_id: ModelId(preset.id.into()),
-                name: preset.label.into(),
-                description: Some(
-                    preset
-                        .description
-                        .strip_prefix("— ")
-                        .unwrap_or(preset.description)
-                        .into(),
-                ),
-                meta: None,
-            })
-            .collect();
+        let available_models: Vec<ModelInfo> = if self.model_ctx.model_provider.name.is_empty() {
+            self.model_ctx
+                .model_presets
+                .iter()
+                .map(|preset| ModelInfo {
+                    model_id: ModelId(preset.id.into()),
+                    name: preset.label.into(),
+                    description: Some(
+                        preset
+                            .description
+                            .strip_prefix("— ")
+                            .unwrap_or(preset.description)
+                            .into(),
+                    ),
+                    meta: None,
+                })
+                .collect()
+        } else {
+            vec![ModelInfo {
+                model_id: ModelId(self.model_ctx.model.clone().into()),
+                name: format!(
+                    "{}/{}",
+                    self.model_ctx.model_provider.name,
+                    self.model_ctx.model.as_str()
+                ),
+                description: Some("Configured model provided by the model provider".into()),
+                meta: None,
+            }]
+        };

         Ok(SessionModelState {
             current_model_id,
@@ -1735,6 +1788,7 @@ impl<A: Auth> ConversationActor<A> {

     async fn handle_set_model(&mut self, model: ModelId) -> Result<(), Error> {
         let preset = self
+            .model_ctx
             .model_presets
             .iter()
             .find(|p| p.id == model.0.as_ref())
@@ -2422,7 +2476,24 @@ mod tests {
             session_client,
             conversation.clone(),
             config,
-            Default::default(),
+            ConversationModelContext {
+                model_presets: Default::default(),
+                model_provider: ModelProviderInfo {
+                    name: "test".to_string(),
+                    base_url: Some("https://api.test.com".to_string()),
+                    env_key: None,
+                    env_key_instructions: None,
+                    wire_api: Default::default(),
+                    query_params: Default::default(),
+                    http_headers: Default::default(),
+                    env_http_headers: Default::default(),
+                    request_max_retries: Default::default(),
+                    stream_max_retries: Default::default(),
+                    stream_idle_timeout_ms: Default::default(),
+                    requires_openai_auth: false,
+                },
+                model: String::from("gpt-5"),
+            },
             message_rx,
         );
         actor.custom_prompts = Rc::new(RefCell::new(custom_prompts));
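Note: the delta-handling changes above boil down to one rule per stream: once any streaming delta has been forwarded, the final aggregated event is dropped. A standalone sketch of that pattern, using hypothetical event and state types rather than codex_core's real ones:

// Standalone sketch of the delta-deduplication pattern introduced above.
// AgentEvent and PromptState here are simplified stand-ins, not the real types.
enum AgentEvent {
    MessageDelta(String), // streamed chunk
    Message(String),      // final aggregated message
}

#[derive(Default)]
struct PromptState {
    saw_message_delta: bool,
}

impl PromptState {
    /// Returns the text to forward to the client, if any.
    fn on_event(&mut self, event: AgentEvent) -> Option<String> {
        match event {
            AgentEvent::MessageDelta(delta) => {
                // Remember that this prompt streamed deltas so the final
                // aggregated Message can be skipped later.
                self.saw_message_delta = true;
                Some(delta)
            }
            AgentEvent::Message(full) => {
                // Deltas already carried this content; forward nothing.
                // Otherwise forward the full message exactly once.
                if self.saw_message_delta { None } else { Some(full) }
            }
        }
    }
}

fn main() {
    let mut state = PromptState::default();
    assert_eq!(state.on_event(AgentEvent::MessageDelta("Hel".into())), Some("Hel".into()));
    assert_eq!(state.on_event(AgentEvent::MessageDelta("lo".into())), Some("lo".into()));
    // Final aggregated message is suppressed because deltas were streamed.
    assert_eq!(state.on_event(AgentEvent::Message("Hello".into())), None);
}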
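Note: the `available_models` change above selects between the built-in presets and a single provider-qualified entry. The same branch in isolation, with hypothetical minimal types standing in for the real `ModelInfo` and `ModelProviderInfo`:

// Standalone sketch of the model-list selection above: built-in presets when no
// custom provider is named, otherwise one "provider/model" entry. Preset and the
// string return type are simplified stand-ins for the real structs.
struct Preset {
    id: &'static str,
    label: &'static str,
}

fn available_models(provider_name: &str, model: &str, presets: &[Preset]) -> Vec<String> {
    if provider_name.is_empty() {
        // No custom provider configured: expose every built-in preset.
        presets.iter().map(|p| format!("{} ({})", p.label, p.id)).collect()
    } else {
        // A custom provider exposes exactly one entry, named "provider/model".
        vec![format!("{provider_name}/{model}")]
    }
}

fn main() {
    let presets = [Preset { id: "gpt-5-high", label: "GPT-5 High" }];
    assert_eq!(available_models("", "gpt-5", &presets), vec!["GPT-5 High (gpt-5-high)"]);
    assert_eq!(available_models("openrouter", "qwen3", &presets), vec!["openrouter/qwen3"]);
}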