diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 1443b2388..d353b016c 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -760,6 +760,7 @@ dependencies = [ "json-five", "json5", "log", + "mac_address", "objc2 0.5.2", "objc2-app-kit 0.2.2", "once_cell", @@ -2882,6 +2883,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" +[[package]] +name = "mac_address" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0aeb26bf5e836cc1c341c8106051b573f1766dfa05aa87f0b98be5e51b02303" +dependencies = [ + "nix", + "winapi", +] + [[package]] name = "markup5ever" version = "0.14.1" @@ -3062,6 +3073,19 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "cfg_aliases", + "libc", + "memoffset", +] + [[package]] name = "nodrop" version = "0.1.14" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index e62bc1e36..94fdad7a9 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -76,6 +76,7 @@ indexmap = { version = "2", features = ["serde"] } rust_decimal = "1.33" uuid = { version = "1.11", features = ["v4"] } sha2 = "0.10" +mac_address = "1.1" json5 = "0.4" json-five = "0.3.1" diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 7f17c0b43..c8ef42b17 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -794,8 +794,17 @@ pub fn run() { use tokio::sync::RwLock; let app_config_dir = crate::config::get_app_config_dir(); - let copilot_auth_manager = CopilotAuthManager::new(app_config_dir); - app.manage(CopilotAuthState(Arc::new(RwLock::new(copilot_auth_manager)))); + let copilot_auth_manager = Arc::new(RwLock::new(CopilotAuthManager::new( + app_config_dir, + ))); + app.manage(CopilotAuthState(Arc::clone(&copilot_auth_manager))); + tauri::async_runtime::spawn(async move { + copilot_auth_manager + .read() + .await + .initialize_background_tasks() + .await; + }); log::info!("✓ CopilotAuthManager initialized"); } diff --git a/src-tauri/src/proxy/forwarder.rs b/src-tauri/src/proxy/forwarder.rs index c4172ea3d..67e692be7 100644 --- a/src-tauri/src/proxy/forwarder.rs +++ b/src-tauri/src/proxy/forwarder.rs @@ -38,6 +38,20 @@ pub struct ForwardError { pub provider: Option, } +pub struct ForwardRequestInput { + pub endpoint: String, + pub body: Value, + pub headers: axum::http::HeaderMap, + pub extensions: Extensions, + pub client_session_id: Option, +} + +struct ForwardMetadata<'a> { + headers: &'a axum::http::HeaderMap, + extensions: &'a Extensions, + client_session_id: Option<&'a str>, +} + pub struct RequestForwarder { /// 共享的 ProviderRouter(持有熔断器状态) router: Arc, @@ -100,12 +114,16 @@ impl RequestForwarder { pub async fn forward_with_retry( &self, app_type: &AppType, - endpoint: &str, - body: Value, - headers: axum::http::HeaderMap, - extensions: Extensions, + request: ForwardRequestInput, providers: Vec, ) -> Result { + let ForwardRequestInput { + endpoint, + body, + headers, + extensions, + client_session_id, + } = request; // 获取适配器 let adapter = get_adapter(app_type); let app_type_str = app_type.as_str(); @@ -177,10 +195,13 @@ impl RequestForwarder { match self .forward( provider, - endpoint, + &endpoint, &provider_body, - &headers, - &extensions, + ForwardMetadata { + headers: &headers, + extensions: &extensions, + client_session_id: client_session_id.as_deref(), + }, adapter.as_ref(), ) .await @@ -307,10 +328,13 @@ impl RequestForwarder { match self .forward( provider, - endpoint, + &endpoint, &provider_body, - &headers, - &extensions, + ForwardMetadata { + headers: &headers, + extensions: &extensions, + client_session_id: client_session_id.as_deref(), + }, adapter.as_ref(), ) .await @@ -506,10 +530,13 @@ impl RequestForwarder { match self .forward( provider, - endpoint, + &endpoint, &provider_body, - &headers, - &extensions, + ForwardMetadata { + headers: &headers, + extensions: &extensions, + client_session_id: client_session_id.as_deref(), + }, adapter.as_ref(), ) .await @@ -747,14 +774,16 @@ impl RequestForwarder { provider: &Provider, endpoint: &str, body: &Value, - headers: &axum::http::HeaderMap, - extensions: &Extensions, + metadata: ForwardMetadata<'_>, adapter: &dyn ProviderAdapter, ) -> Result<(ProxyResponse, Option), ProxyError> { + let headers = metadata.headers; + let extensions = metadata.extensions; + let client_session_id = metadata.client_session_id; // 使用适配器提取 base_url let mut base_url = adapter.extract_base_url(provider)?; - let is_full_url = provider + let has_explicit_full_url = provider .meta .as_ref() .and_then(|meta| meta.is_full_url) @@ -831,36 +860,6 @@ impl RequestForwarder { } else { None }; - - // GitHub Copilot 动态 endpoint 路由 - // 从 CopilotAuthManager 获取缓存的 API endpoint(支持企业版等非默认 endpoint) - if is_copilot && !is_full_url { - if let Some(app_handle) = &self.app_handle { - let copilot_state = app_handle.state::(); - let copilot_auth = copilot_state.0.read().await; - - // 从 provider.meta 获取关联的 GitHub 账号 ID - let account_id = provider - .meta - .as_ref() - .and_then(|m| m.managed_account_id_for("github_copilot")); - - let dynamic_endpoint = match &account_id { - Some(id) => copilot_auth.get_api_endpoint(id).await, - None => copilot_auth.get_default_api_endpoint().await, - }; - - // 只在动态 endpoint 与当前 base_url 不同时替换 - if dynamic_endpoint != base_url { - log::debug!( - "[Copilot] 使用动态 API endpoint: {} (原: {})", - dynamic_endpoint, - base_url - ); - base_url = dynamic_endpoint; - } - } - } let resolved_claude_api_format = if adapter.name() == "Claude" { Some( self.resolve_claude_api_format(provider, &mapped_body, is_copilot) @@ -888,6 +887,34 @@ impl RequestForwarder { ) }; + let is_full_url = has_explicit_full_url + || is_legacy_full_url_for_endpoint(&base_url, &effective_endpoint); + // GitHub Copilot 动态 endpoint 路由 + // 从 CopilotAuthManager 获取缓存的 API endpoint(支持企业版等非默认 endpoint) + if is_copilot && !is_full_url { + if let Some(app_handle) = &self.app_handle { + let copilot_state = app_handle.state::(); + let copilot_auth = copilot_state.0.read().await; + + // 从 provider.meta 获取关联的 GitHub 账号 ID + let account_id = provider + .meta + .as_ref() + .and_then(|m| m.managed_account_id_for("github_copilot")); + + let dynamic_endpoint = match &account_id { + Some(id) => copilot_auth.get_api_endpoint(id).await, + None => copilot_auth.get_default_api_endpoint().await, + }; + + // 只在动态 endpoint 与当前 base_url 不同时替换 + if dynamic_endpoint != base_url { + log::debug!("[Copilot] 已启用动态 API endpoint"); + base_url = dynamic_endpoint; + } + } + } + let url = if is_full_url { append_query_to_full_url(&base_url, passthrough_query.as_deref()) } else { @@ -950,10 +977,27 @@ impl RequestForwarder { match token_result { Ok(token) => { - auth = AuthInfo::new(token, AuthStrategy::GitHubCopilot); + // 获取 machine ID 和 session ID + let (machine_id, session_id) = copilot_auth + .get_account_ids_for_conversation( + account_id.as_deref(), + client_session_id, + ) + .await; + + // 使用新的构造函数 + auth = AuthInfo::new_with_ids( + token, + AuthStrategy::GitHubCopilot, + machine_id.clone(), + session_id.clone(), + ); + log::debug!( - "[Copilot] 成功获取 Copilot token (account={})", - account_id.as_deref().unwrap_or("default") + "[Copilot] 成功获取 Copilot token (account={}, has_machine_id={}, has_session_id={})", + account_id.as_deref().unwrap_or("default"), + machine_id.is_some(), + session_id.is_some(), ); } Err(e) => { @@ -1026,7 +1070,7 @@ impl RequestForwarder { } } - adapter.get_auth_headers(&auth) + adapter.get_auth_headers(&auth, Some(&filtered_body)) } else { Vec::new() }; @@ -1072,6 +1116,9 @@ impl RequestForwarder { "x-vscode-user-agent-library-version", "x-request-id", "x-agent-task-id", + // Machine ID 和 Session ID + "vscode-machineid", + "vscode-sessionid", ] } else { &[] @@ -1637,6 +1684,23 @@ fn append_query_to_full_url(base_url: &str, query: Option<&str>) -> String { } } +fn is_legacy_full_url_for_endpoint(base_url: &str, endpoint: &str) -> bool { + let Ok(parsed) = url::Url::parse(base_url) else { + return false; + }; + + let path = parsed.path().trim_end_matches('/'); + let endpoint_path = split_endpoint_and_query(endpoint).0.trim_end_matches('/'); + + if path.is_empty() || path == "/" || endpoint_path.is_empty() || endpoint_path == "/" { + return false; + } + + let normalized_endpoint_path = format!("/{}", endpoint_path.trim_start_matches('/')); + + path == endpoint_path || path.ends_with(&normalized_endpoint_path) +} + fn should_force_identity_encoding( endpoint: &str, body: &Value, @@ -1945,4 +2009,24 @@ mod tests { assert_eq!(will_replace, should_replace, "{desc}"); } } + + #[test] + fn legacy_full_url_detection_matches_endpoint_path() { + assert!(is_legacy_full_url_for_endpoint( + "https://proxy.example.com/chat/completions", + "/chat/completions" + )); + assert!(is_legacy_full_url_for_endpoint( + "https://proxy.example.com/api/v1/messages", + "/v1/messages" + )); + assert!(!is_legacy_full_url_for_endpoint( + "https://proxy.example.com/api", + "/chat/completions" + )); + assert!(!is_legacy_full_url_for_endpoint( + "https://proxy.example.com/v1", + "/v1/messages" + )); + } } diff --git a/src-tauri/src/proxy/handler_context.rs b/src-tauri/src/proxy/handler_context.rs index bad855a91..a90302ca5 100644 --- a/src-tauri/src/proxy/handler_context.rs +++ b/src-tauri/src/proxy/handler_context.rs @@ -57,6 +57,8 @@ pub struct RequestContext { pub app_type: AppType, /// Session ID(从客户端请求提取或新生成) pub session_id: String, + /// 客户端显式提供的 Session ID(仅当请求携带对话连续性线索时存在) + pub client_session_id: Option, /// 整流器配置 pub rectifier_config: RectifierConfig, /// 优化器配置 @@ -113,6 +115,9 @@ impl RequestContext { // 提取 Session ID let session_result = extract_session_id(headers, body, app_type_str); let session_id = session_result.session_id.clone(); + let client_session_id = session_result + .client_provided + .then(|| session_result.session_id.clone()); log::debug!( "[{}] Session ID: {} (from {:?}, client_provided: {})", @@ -161,6 +166,7 @@ impl RequestContext { app_type_str, app_type, session_id, + client_session_id, rectifier_config, optimizer_config, copilot_optimizer_config, diff --git a/src-tauri/src/proxy/handlers.rs b/src-tauri/src/proxy/handlers.rs index 08816246a..299e2abd5 100644 --- a/src-tauri/src/proxy/handlers.rs +++ b/src-tauri/src/proxy/handlers.rs @@ -28,6 +28,7 @@ use super::{ ProxyError, }; use crate::app_config::AppType; +use crate::proxy::forwarder::ForwardRequestInput; use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; use bytes::Bytes; use http_body_util::BodyExt; @@ -97,10 +98,13 @@ pub async fn handle_messages( let result = match forwarder .forward_with_retry( &AppType::Claude, - endpoint, - body.clone(), - headers, - extensions, + ForwardRequestInput { + endpoint: endpoint.to_string(), + body: body.clone(), + headers, + extensions, + client_session_id: ctx.client_session_id.clone(), + }, ctx.get_providers(), ) .await @@ -347,10 +351,13 @@ pub async fn handle_chat_completions( let result = match forwarder .forward_with_retry( &AppType::Codex, - &endpoint, - body, - headers, - extensions, + ForwardRequestInput { + endpoint, + body, + headers, + extensions, + client_session_id: ctx.client_session_id.clone(), + }, ctx.get_providers(), ) .await @@ -401,10 +408,13 @@ pub async fn handle_responses( let result = match forwarder .forward_with_retry( &AppType::Codex, - &endpoint, - body, - headers, - extensions, + ForwardRequestInput { + endpoint, + body, + headers, + extensions, + client_session_id: ctx.client_session_id.clone(), + }, ctx.get_providers(), ) .await @@ -455,10 +465,13 @@ pub async fn handle_responses_compact( let result = match forwarder .forward_with_retry( &AppType::Codex, - &endpoint, - body, - headers, - extensions, + ForwardRequestInput { + endpoint, + body, + headers, + extensions, + client_session_id: ctx.client_session_id.clone(), + }, ctx.get_providers(), ) .await @@ -520,10 +533,13 @@ pub async fn handle_gemini( let result = match forwarder .forward_with_retry( &AppType::Gemini, - endpoint, - body, - headers, - extensions, + ForwardRequestInput { + endpoint: endpoint.to_string(), + body, + headers, + extensions, + client_session_id: ctx.client_session_id.clone(), + }, ctx.get_providers(), ) .await diff --git a/src-tauri/src/proxy/providers/adapter.rs b/src-tauri/src/proxy/providers/adapter.rs index d764d3eb2..2c500306c 100644 --- a/src-tauri/src/proxy/providers/adapter.rs +++ b/src-tauri/src/proxy/providers/adapter.rs @@ -30,7 +30,15 @@ pub trait ProviderAdapter: Send + Sync { /// /// The forwarder inserts these at the position of the original auth header /// so that header order is preserved. - fn get_auth_headers(&self, auth: &AuthInfo) -> Vec<(http::HeaderName, http::HeaderValue)>; + /// + /// # Parameters + /// - `auth`: Authentication information + /// - `request_body`: Optional request body for generating deterministic request IDs (used by Copilot) + fn get_auth_headers( + &self, + auth: &AuthInfo, + request_body: Option<&Value>, + ) -> Vec<(http::HeaderName, http::HeaderValue)>; /// 是否需要格式转换 fn needs_transform(&self, _provider: &Provider) -> bool { diff --git a/src-tauri/src/proxy/providers/auth.rs b/src-tauri/src/proxy/providers/auth.rs index 6bd9a02a5..66f376dd7 100644 --- a/src-tauri/src/proxy/providers/auth.rs +++ b/src-tauri/src/proxy/providers/auth.rs @@ -13,6 +13,10 @@ pub struct AuthInfo { pub strategy: AuthStrategy, /// OAuth access_token(用于 GoogleOAuth 策略) pub access_token: Option, + /// Machine ID(用于 GitHub Copilot) + pub machine_id: Option, + /// Session ID(用于 GitHub Copilot) + pub session_id: Option, } impl AuthInfo { @@ -22,6 +26,8 @@ impl AuthInfo { api_key, strategy, access_token: None, + machine_id: None, + session_id: None, } } @@ -31,6 +37,24 @@ impl AuthInfo { api_key, strategy: AuthStrategy::GoogleOAuth, access_token: Some(access_token), + machine_id: None, + session_id: None, + } + } + + /// 创建带有 machine_id 和 session_id 的认证信息(用于 GitHub Copilot) + pub fn new_with_ids( + api_key: String, + strategy: AuthStrategy, + machine_id: Option, + session_id: Option, + ) -> Self { + Self { + api_key, + strategy, + access_token: None, + machine_id, + session_id, } } diff --git a/src-tauri/src/proxy/providers/claude.rs b/src-tauri/src/proxy/providers/claude.rs index 52689d3b5..5d3e0e757 100644 --- a/src-tauri/src/proxy/providers/claude.rs +++ b/src-tauri/src/proxy/providers/claude.rs @@ -16,6 +16,164 @@ use super::{AuthInfo, AuthStrategy, ProviderAdapter, ProviderType}; use crate::provider::Provider; use crate::proxy::error::ProxyError; +use serde_json::Value; + +/// 为每个上游请求生成唯一 request ID。 +fn generate_request_id() -> String { + uuid::Uuid::new_v4().to_string() +} + +/// 判断当前 Copilot 请求是否属于 agent continuation / tool loop。 +fn is_agent_initiated(request_body: Option<&Value>) -> bool { + let Some(body) = request_body else { + return false; + }; + + if let Some(messages) = body.get("messages").and_then(|msgs| msgs.as_array()) { + if messages.is_empty() { + return false; + } + + let last_index = match messages.iter().rposition(|msg| msg.get("role").is_some()) { + Some(index) => index, + None => return false, + }; + let last_message = &messages[last_index]; + let last_role = last_message + .get("role") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + + if matches!(last_role, "assistant" | "tool") { + return true; + } + + if last_role == "user" { + if last_message_contains_tool_result(last_message) { + return true; + } + + if last_index >= 1 && previous_message_contains_tool_use(&messages[last_index - 1]) { + return true; + } + } + + return false; + } + + if let Some(inputs) = body.get("input").and_then(|input| input.as_array()) { + if inputs.is_empty() { + return false; + } + + let last_index = inputs.len() - 1; + let last = &inputs[last_index]; + if last + .get("role") + .and_then(|value| value.as_str()) + .is_some_and(|role| role == "assistant") + { + return true; + } + + if input_item_is_agent_continuation(last) { + return true; + } + + if last + .get("role") + .and_then(|value| value.as_str()) + .is_some_and(|role| role == "user") + && last_message_contains_tool_result(last) + { + return true; + } + + if last_index >= 1 { + let previous = &inputs[last_index - 1]; + if input_item_requests_tool_continuation(previous) { + return last_message_contains_tool_result(last) + || input_item_is_agent_continuation(last); + } + } + + return false; + } + + false +} + +fn last_message_contains_tool_result(message: &Value) -> bool { + message + .get("content") + .and_then(|content| content.as_array()) + .is_some_and(|parts| { + parts.iter().any(|part| { + part.get("type") + .and_then(|value| value.as_str()) + .is_some_and(|part_type| part_type == "tool_result") + }) + }) +} + +fn previous_message_contains_tool_use(message: &Value) -> bool { + if message + .get("role") + .and_then(|value| value.as_str()) + .is_none_or(|role| role != "assistant") + { + return false; + } + + message + .get("content") + .and_then(|content| content.as_array()) + .is_some_and(|parts| { + parts.iter().any(|part| { + part.get("type") + .and_then(|value| value.as_str()) + .is_some_and(|part_type| part_type == "tool_use") + }) + }) +} + +fn input_item_is_agent_continuation(item: &Value) -> bool { + if item + .get("role") + .and_then(|value| value.as_str()) + .is_some_and(|role| role == "assistant") + { + return true; + } + + matches!( + item.get("type").and_then(|value| value.as_str()), + Some( + "function_call" + | "function_call_arguments" + | "computer_call" + | "function_call_output" + | "function_call_response" + | "tool_result" + | "computer_call_output" + ) + ) +} + +fn input_item_requests_tool_continuation(item: &Value) -> bool { + if item + .get("role") + .and_then(|value| value.as_str()) + .is_some_and(|role| role == "assistant") + { + return true; + } + + matches!( + item.get("type").and_then(|value| value.as_str()), + Some("function_call" | "function_call_arguments" | "computer_call") + ) +} /// 获取 Claude 供应商的 API 格式 /// @@ -390,7 +548,11 @@ impl ProviderAdapter for ClaudeAdapter { base } - fn get_auth_headers(&self, auth: &AuthInfo) -> Vec<(http::HeaderName, http::HeaderValue)> { + fn get_auth_headers( + &self, + auth: &AuthInfo, + request_body: Option<&serde_json::Value>, + ) -> Vec<(http::HeaderName, http::HeaderValue)> { use http::{HeaderName, HeaderValue}; // 注意:anthropic-version 由 forwarder.rs 统一处理(透传客户端值或设置默认值) let bearer = format!("Bearer {}", auth.api_key); @@ -416,9 +578,14 @@ impl ProviderAdapter for ClaudeAdapter { ] } AuthStrategy::GitHubCopilot => { - // 生成请求追踪 ID - let request_id = uuid::Uuid::new_v4().to_string(); - vec![ + let request_id = generate_request_id(); + let initiator = if is_agent_initiated(request_body) { + "agent" + } else { + "user" + }; + + let mut headers = vec![ ( HeaderName::from_static("authorization"), HeaderValue::from_str(&bearer).unwrap(), @@ -450,7 +617,7 @@ impl ProviderAdapter for ClaudeAdapter { ), ( HeaderName::from_static("x-initiator"), - HeaderValue::from_static("user"), + HeaderValue::from_static(initiator), ), ( HeaderName::from_static("x-interaction-type"), @@ -468,7 +635,24 @@ impl ProviderAdapter for ClaudeAdapter { HeaderName::from_static("x-agent-task-id"), HeaderValue::from_str(&request_id).unwrap(), ), - ] + ]; + + // 添加 machine ID 和 session ID(如果存在) + if let Some(machine_id) = &auth.machine_id { + headers.push(( + HeaderName::from_static("vscode-machineid"), + HeaderValue::from_str(machine_id).unwrap(), + )); + } + + if let Some(session_id) = &auth.session_id { + headers.push(( + HeaderName::from_static("vscode-sessionid"), + HeaderValue::from_str(session_id).unwrap(), + )); + } + + headers } _ => vec![], } @@ -557,6 +741,103 @@ mod tests { } } + #[test] + fn test_is_agent_initiated_for_tool_result_continuation() { + let body = json!({ + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "tool-1", "name": "search", "input": {}} + ] + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "tool-1", "content": "ok"} + ] + } + ] + }); + + assert!(is_agent_initiated(Some(&body))); + } + + #[test] + fn test_is_agent_initiated_for_plain_user_message() { + let body = json!({ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"} + ] + } + ] + }); + + assert!(!is_agent_initiated(Some(&body))); + } + + #[test] + fn test_is_agent_initiated_for_responses_tool_output() { + let body = json!({ + "input": [ + {"type": "function_call_output", "call_id": "call-1", "output": "ok"} + ] + }); + + assert!(is_agent_initiated(Some(&body))); + } + + #[test] + fn test_is_agent_initiated_for_plain_follow_up_after_tool_history() { + let body = json!({ + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "tool-1", "name": "search", "input": {}} + ] + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "tool-1", "content": "ok"} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I found the result"} + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "请继续解释这个结果"} + ] + } + ] + }); + + assert!(!is_agent_initiated(Some(&body))); + } + + #[test] + fn test_is_agent_initiated_for_responses_plain_follow_up_after_tool_history() { + let body = json!({ + "input": [ + {"type": "function_call", "call_id": "call-1", "name": "search", "arguments": "{}"}, + {"type": "function_call_output", "call_id": "call-1", "output": "ok"}, + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I found the result"}]}, + {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "请继续解释这个结果"}]} + ] + }); + + assert!(!is_agent_initiated(Some(&body))); + } + #[test] fn test_extract_base_url_from_env() { let adapter = ClaudeAdapter::new(); diff --git a/src-tauri/src/proxy/providers/codex.rs b/src-tauri/src/proxy/providers/codex.rs index eebd33704..600fc991f 100644 --- a/src-tauri/src/proxy/providers/codex.rs +++ b/src-tauri/src/proxy/providers/codex.rs @@ -9,6 +9,7 @@ use super::{AuthInfo, AuthStrategy, ProviderAdapter}; use crate::provider::Provider; use crate::proxy::error::ProxyError; use regex::Regex; +use serde_json::Value; use std::sync::LazyLock; /// 官方 Codex 客户端 User-Agent 正则 @@ -173,7 +174,11 @@ impl ProviderAdapter for CodexAdapter { url } - fn get_auth_headers(&self, auth: &AuthInfo) -> Vec<(http::HeaderName, http::HeaderValue)> { + fn get_auth_headers( + &self, + auth: &AuthInfo, + _request_body: Option<&Value>, + ) -> Vec<(http::HeaderName, http::HeaderValue)> { let bearer = format!("Bearer {}", auth.api_key); vec![( http::HeaderName::from_static("authorization"), diff --git a/src-tauri/src/proxy/providers/copilot_auth.rs b/src-tauri/src/proxy/providers/copilot_auth.rs index 44a232712..18fe7dbc8 100644 --- a/src-tauri/src/proxy/providers/copilot_auth.rs +++ b/src-tauri/src/proxy/providers/copilot_auth.rs @@ -295,6 +295,15 @@ struct GitHubAccountData { pub user: GitHubUser, /// 认证时间戳 pub authenticated_at: i64, + /// 机器 ID(基于 MAC 地址的 SHA256 哈希,每个账号独立) + #[serde(default)] + pub machine_id: Option, + /// 会话 ID(UUID + 时间戳,每小时刷新) + #[serde(default)] + pub session_id: Option, + /// Session ID 上次刷新时间 + #[serde(default)] + pub session_refreshed_at: Option, } /// 持久化存储结构(v3 多账号 + 默认账号格式) @@ -330,6 +339,8 @@ pub struct CopilotAuthManager { copilot_models: Arc>>>, /// Copilot API 端点缓存(key = GitHub user ID,从 /copilot_internal/user 获取) api_endpoints: Arc>>, + /// 运行时对话会话桥接(key = "{account_id}:{client_session_id}") + conversation_sessions: Arc>>, /// 每个账号的端点拉取锁,避免并发拉取重复打 GitHub API endpoint_locks: Arc>>>>, /// HTTP 客户端 @@ -340,9 +351,13 @@ pub struct CopilotAuthManager { pending_migration: Arc>>, /// 旧认证数据迁移失败时的状态消息 migration_error: Arc>>, + /// Session ID 刷新任务句柄(key = GitHub user ID) + session_refresh_tasks: Arc>>>, } impl CopilotAuthManager { + const SESSION_REFRESH_BASE_SECS: i64 = 3600; + /// 创建新的认证管理器 pub fn new(data_dir: PathBuf) -> Self { let storage_path = data_dir.join("copilot_auth.json"); @@ -354,11 +369,13 @@ impl CopilotAuthManager { copilot_tokens: Arc::new(RwLock::new(HashMap::new())), copilot_models: Arc::new(RwLock::new(HashMap::new())), api_endpoints: Arc::new(RwLock::new(HashMap::new())), + conversation_sessions: Arc::new(RwLock::new(HashMap::new())), endpoint_locks: Arc::new(RwLock::new(HashMap::new())), http_client: Client::new(), storage_path, pending_migration: Arc::new(RwLock::new(None)), migration_error: Arc::new(RwLock::new(None)), + session_refresh_tasks: Arc::new(RwLock::new(HashMap::new())), }; // 尝试从磁盘加载(同步,不发起网络请求) @@ -369,6 +386,25 @@ impl CopilotAuthManager { manager } + /// 启动已加载账号的后台 Session ID 刷新任务。 + /// + /// 这个方法要求调用方已经处于可用的异步运行时中。 + pub async fn initialize_background_tasks(&self) { + let account_ids = { + let accounts_guard = self.accounts.read().await; + accounts_guard.keys().cloned().collect::>() + }; + + for account_id in account_ids { + Self::spawn_session_refresh_task( + account_id, + Arc::clone(&self.accounts), + Arc::clone(&self.session_refresh_tasks), + ) + .await; + } + } + // ==================== 多账号管理方法 ==================== /// 列出所有已认证的账号 @@ -413,11 +449,28 @@ impl CopilotAuthManager { let mut api_endpoints = self.api_endpoints.write().await; api_endpoints.remove(account_id); } + { + let mut conversation_sessions = self.conversation_sessions.write().await; + let prefix = format!("{account_id}:"); + conversation_sessions.retain(|key, _| !key.starts_with(&prefix)); + } { let mut endpoint_locks = self.endpoint_locks.write().await; endpoint_locks.remove(account_id); } + // 停止 Session ID 刷新任务 + { + let mut tasks = self.session_refresh_tasks.write().await; + if let Some(task) = tasks.remove(account_id) { + task.abort(); + log::debug!( + "[CopilotAuth] 已停止账号 {} 的 Session ID 刷新任务", + account_id + ); + } + } + { let accounts = self.accounts.read().await; let mut default_account_id = self.default_account_id.write().await; @@ -437,14 +490,21 @@ impl CopilotAuthManager { &self, github_token: String, user: GitHubUser, + machine_id: String, ) -> Result { let account_id = user.id.to_string(); let now = chrono::Utc::now().timestamp(); + // 生成初始 Session ID + let session_id = Self::generate_session_id(); + let account_data = GitHubAccountData { github_token, user: user.clone(), authenticated_at: now, + machine_id: Some(machine_id), + session_id: Some(session_id.clone()), + session_refreshed_at: Some(now), }; let account = GitHubAccount { @@ -456,7 +516,7 @@ impl CopilotAuthManager { { let mut accounts = self.accounts.write().await; - accounts.insert(account_id, account_data); + accounts.insert(account_id.clone(), account_data); } { @@ -471,6 +531,14 @@ impl CopilotAuthManager { // 持久化 self.save_to_disk().await?; + // 启动 Session ID 刷新任务 + Self::spawn_session_refresh_task( + account_id.clone(), + Arc::clone(&self.accounts), + Arc::clone(&self.session_refresh_tasks), + ) + .await; + log::info!("[CopilotAuth] 添加账号成功: {}", user.login); Ok(account) @@ -584,8 +652,43 @@ impl CopilotAuthManager { self.fetch_copilot_token_with_github_token(&access_token, &user.id.to_string()) .await?; + let account_id = user.id.to_string(); + + // 检查账号是否已存在,以及是否需要重新生成 Machine ID + let machine_id = { + let accounts = self.accounts.read().await; + if let Some(existing_account) = accounts.get(&account_id) { + let now = chrono::Utc::now().timestamp(); + let time_since_auth = now - existing_account.authenticated_at; + + // 如果在 1 小时内重新登录,复用现有的 Machine ID + if time_since_auth < 3600 && existing_account.machine_id.is_some() { + log::info!( + "[CopilotAuth] 账号 {} 在 1 小时内重新登录,复用现有 Machine ID", + account_id + ); + existing_account.machine_id.clone().unwrap() + } else { + // 超过 1 小时,重新生成 Machine ID + let new_machine_id = self.generate_machine_id(&account_id); + log::debug!( + "[CopilotAuth] 账号 {} 超过 1 小时后重新登录,已生成新 Machine ID", + account_id + ); + new_machine_id + } + } else { + // 新账号,生成 Machine ID + let new_machine_id = self.generate_machine_id(&account_id); + log::debug!("[CopilotAuth] 为新账号 {} 生成 Machine ID", account_id); + new_machine_id + } + }; + // 添加账号 - let account = self.add_account_internal(access_token, user).await?; + let account = self + .add_account_internal(access_token, user, machine_id) + .await?; Ok(Some(account)) } @@ -844,6 +947,15 @@ impl CopilotAuthManager { } } + // 加锁前先检查账号是否存在,避免为不存在的账号永久插入锁 + { + let accounts = self.accounts.read().await; + if !accounts.contains_key(account_id) { + log::debug!("[CopilotAuth] 账号 {account_id} 不存在,使用默认 endpoint"); + return DEFAULT_COPILOT_API_ENDPOINT.to_string(); + } + } + // 用锁串行化同一账号的并发拉取,避免对 GitHub API 的重复请求 let lock = self.get_endpoint_lock(account_id).await; let _guard = lock.lock().await; @@ -944,7 +1056,6 @@ impl CopilotAuthManager { .or_insert_with(|| Arc::new(Mutex::new(()))), ) } - /// 获取认证状态(支持多账号) pub async fn get_status(&self) -> CopilotAuthStatus { // 确保迁移完成 @@ -1011,19 +1122,38 @@ impl CopilotAuthManager { let mut refresh_locks = self.refresh_locks.write().await; refresh_locks.clear(); } + { + let mut session_refresh_tasks = self.session_refresh_tasks.write().await; + for (_, task) in session_refresh_tasks.drain() { + task.abort(); + } + } // 清理 API 端点缓存 { let mut api_endpoints = self.api_endpoints.write().await; api_endpoints.clear(); } + { + let mut conversation_sessions = self.conversation_sessions.write().await; + conversation_sessions.clear(); + } { let mut endpoint_locks = self.endpoint_locks.write().await; endpoint_locks.clear(); } - // 最后删除存储文件 if self.storage_path.exists() { - std::fs::remove_file(&self.storage_path)?; + if let Err(err) = std::fs::remove_file(&self.storage_path) { + // NotFound 可以忽略(文件可能已被删除) + // 其他错误只记录警告,不阻止登出(内存状态已清理,用户体验上已登出) + if err.kind() != std::io::ErrorKind::NotFound { + log::warn!( + "[CopilotAuth] 删除认证文件失败 {}: {},内存状态已清理", + self.storage_path.display(), + err + ); + } + } } Ok(()) @@ -1075,6 +1205,86 @@ impl CopilotAuthManager { Self::fallback_default_account_id(&accounts) } + /// 获取账号的 machine ID 和 session ID(公共方法) + pub async fn get_account_ids( + &self, + account_id: Option<&str>, + ) -> (Option, Option) { + self.get_account_ids_for_conversation(account_id, None) + .await + } + + /// 获取账号的 machine ID 和对话级 session ID。 + /// + /// 当 `client_session_id` 存在时,为同一账号 + 客户端会话复用稳定的 upstream session; + /// 否则回退到账号级 session_id(向后兼容)。 + pub async fn get_account_ids_for_conversation( + &self, + account_id: Option<&str>, + client_session_id: Option<&str>, + ) -> (Option, Option) { + match account_id { + Some(id) => self.resolve_account_ids(id, client_session_id).await, + None => { + // 使用默认账号 + let default_id = self.resolve_default_account_id().await; + if let Some(id) = default_id { + self.resolve_account_ids(&id, client_session_id).await + } else { + (None, None) + } + } + } + } + + async fn resolve_account_ids( + &self, + account_id: &str, + client_session_id: Option<&str>, + ) -> (Option, Option) { + let machine_id = { + let accounts = self.accounts.read().await; + accounts + .get(account_id) + .and_then(|account_data| account_data.machine_id.clone()) + }; + + let Some(machine_id) = machine_id else { + return (None, None); + }; + + let session_id = match client_session_id + .map(str::trim) + .filter(|value| !value.is_empty()) + { + Some(client_session_id) => { + let key = format!("{account_id}:{client_session_id}"); + { + let conversation_sessions = self.conversation_sessions.read().await; + if let Some(existing) = conversation_sessions.get(&key) { + return (Some(machine_id), Some(existing.clone())); + } + } + + let new_session_id = Self::generate_session_id(); + let mut conversation_sessions = self.conversation_sessions.write().await; + let session_id = conversation_sessions + .entry(key) + .or_insert_with(|| new_session_id.clone()) + .clone(); + Some(session_id) + } + None => { + let accounts = self.accounts.read().await; + accounts + .get(account_id) + .and_then(|account_data| account_data.session_id.clone()) + } + }; + + (Some(machine_id), session_id) + } + async fn get_refresh_lock(&self, account_id: &str) -> Arc> { { let refresh_locks = self.refresh_locks.read().await; @@ -1091,6 +1301,193 @@ impl CopilotAuthManager { ) } + async fn apply_refreshed_session_id( + account_id: &str, + new_session_id: String, + now: i64, + accounts: &Arc>>, + ) -> bool { + let mut accounts_guard = accounts.write().await; + if let Some(account_data) = accounts_guard.get_mut(account_id) { + account_data.session_id = Some(new_session_id.clone()); + account_data.session_refreshed_at = Some(now); + } else { + return false; + } + + true + } + + fn machine_base_path(&self) -> PathBuf { + self.storage_path + .parent() + .unwrap_or_else(|| std::path::Path::new(".")) + .join("copilot_machine_base.txt") + } + + fn fallback_machine_base(&self) -> String { + let machine_base_path = self.machine_base_path(); + + if let Ok(existing) = std::fs::read_to_string(&machine_base_path) { + let trimmed = existing.trim(); + if !trimmed.is_empty() { + return trimmed.to_string(); + } + } + + let host_hint = std::env::var("HOSTNAME") + .ok() + .filter(|value| !value.trim().is_empty()) + .or_else(|| { + std::env::var("COMPUTERNAME") + .ok() + .filter(|value| !value.trim().is_empty()) + }) + .unwrap_or_else(|| "unknown-host".to_string()); + let machine_base = format!("{}-{}", host_hint, uuid::Uuid::new_v4()); + + if let Some(parent) = machine_base_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + if let Err(err) = std::fs::write(&machine_base_path, &machine_base) { + log::warn!( + "[CopilotAuth] 无法持久化 machine base {}: {}", + machine_base_path.display(), + err + ); + } + + machine_base + } + + /// 生成机器 ID(每个账号独立的 Machine ID) + /// + /// 为了避免账号间关联,每个账号使用不同的 Machine ID: + /// - 基于 MAC 地址(或随机 UUID)+ 账号 ID 生成 + /// - 对组合值进行 SHA256 哈希 + fn generate_machine_id(&self, account_id: &str) -> String { + use sha2::{Digest, Sha256}; + + // 尝试获取 MAC 地址作为基础 + let base = match mac_address::get_mac_address() { + Ok(Some(mac)) => { + // MAC 地址格式:AA:BB:CC:DD:EE:FF + mac.to_string() + } + _ => { + // 回退:使用主机唯一且持久化的基础值,避免不同设备碰撞 + log::warn!("[CopilotAuth] 无法获取 MAC 地址,使用持久化 machine base"); + self.fallback_machine_base() + } + }; + + // 组合基础值和账号 ID,确保每个账号有不同的 Machine ID + let source = format!("{}-{}", base, account_id); + + // SHA256 哈希 + let mut hasher = Sha256::new(); + hasher.update(source.as_bytes()); + let result = hasher.finalize(); + + // 转换为十六进制字符串 + format!("{:x}", result) + } + + /// 生成会话 ID(UUID + 时间戳) + /// + /// 参考 caozhiyuan/copilot-api 的实现: + /// state.vsCodeSessionId = randomUUID() + Date.now().toString() + fn generate_session_id() -> String { + let uuid = uuid::Uuid::new_v4().to_string(); + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + + format!("{}{}", uuid, timestamp) + } + + fn should_refresh_session_id(session_refreshed_at: Option, now: i64) -> bool { + match session_refreshed_at { + Some(refreshed_at) => { + now.saturating_sub(refreshed_at) >= Self::SESSION_REFRESH_BASE_SECS + } + None => true, + } + } + + /// 启动 Session ID 定期刷新任务(每小时刷新) + /// + /// 参考 caozhiyuan/copilot-api 的实现: + /// - 基础间隔:1 小时 + /// - 添加 0-10% 随机抖动避免集中刷新 + async fn spawn_session_refresh_task( + account_id: String, + accounts: Arc>>, + session_refresh_tasks: Arc>>>, + ) { + use tokio::time::{sleep, Duration}; + + const SESSION_REFRESH_BASE_SECS: u64 = 3600; // 1 小时 + + // 停止旧任务(如果存在) + { + let mut tasks = session_refresh_tasks.write().await; + if let Some(old_task) = tasks.remove(&account_id) { + old_task.abort(); + } + } + + let account_id_clone = account_id.clone(); + let accounts_clone = Arc::clone(&accounts); + + // 启动新任务 + let task = tokio::spawn(async move { + loop { + // 计算下次刷新时间(带随机抖动) + // 使用 SystemTime 生成随机抖动,避免 thread_rng 的 Send 问题 + let jitter_nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .subsec_nanos(); + let jitter_percent = (jitter_nanos % 100) as f64 / 1000.0; // 0-10% + let jitter_secs = (SESSION_REFRESH_BASE_SECS as f64 * jitter_percent) as u64; + let delay_secs = SESSION_REFRESH_BASE_SECS + jitter_secs; + + sleep(Duration::from_secs(delay_secs)).await; + + // 刷新 Session ID + let new_session_id = Self::generate_session_id(); + let now = chrono::Utc::now().timestamp(); + + if Self::apply_refreshed_session_id( + &account_id_clone, + new_session_id, + now, + &accounts_clone, + ) + .await + { + log::debug!( + "[CopilotAuth] 账号 {} 的 Session ID 已刷新", + account_id_clone + ); + } else { + // 账号已删除,退出任务 + log::debug!( + "[CopilotAuth] 账号 {} 已删除,停止 Session ID 刷新任务", + account_id_clone + ); + break; + } + } + }); + + // 保存任务句柄 + let mut tasks = session_refresh_tasks.write().await; + tasks.insert(account_id, task); + } + async fn set_migration_error(&self, message: Option) { let mut migration_error = self.migration_error.write().await; *migration_error = message; @@ -1245,11 +1642,40 @@ impl CopilotAuthManager { } let content = std::fs::read_to_string(&self.storage_path)?; - let store: CopilotAuthStore = serde_json::from_str(&content) + let mut store: CopilotAuthStore = serde_json::from_str(&content) .map_err(|e| CopilotAuthError::ParseError(e.to_string()))?; if store.version >= 2 { - // v2 多账号格式 + // v2/v3 多账号格式 + + // 为没有 machine_id 的账号补齐,并在启动恢复时纠正过期的 session_id + let mut needs_save = false; + let now = chrono::Utc::now().timestamp(); + for (account_id, account_data) in store.accounts.iter_mut() { + if account_data.machine_id.is_none() { + account_data.machine_id = Some(self.generate_machine_id(account_id)); + log::debug!("[CopilotAuth] 为现有账号 {} 生成 Machine ID", account_id); + needs_save = true; + } + + let had_session_id = account_data.session_id.is_some(); + if !had_session_id + || Self::should_refresh_session_id(account_data.session_refreshed_at, now) + { + account_data.session_id = Some(Self::generate_session_id()); + account_data.session_refreshed_at = Some(now); + if had_session_id { + log::info!( + "[CopilotAuth] 为现有账号 {} 刷新过期 Session ID", + account_id + ); + } else { + log::info!("[CopilotAuth] 为现有账号 {} 生成 Session ID", account_id); + } + needs_save = true; + } + } + if let Ok(mut accounts) = self.accounts.try_write() { *accounts = store.accounts; log::info!("[CopilotAuth] 从磁盘加载 {} 个账号", accounts.len()); @@ -1262,6 +1688,20 @@ impl CopilotAuthManager { } } } + + // 如果有修改,直接同步保存回磁盘,避免在无异步运行时的启动阶段触发 panic + if needs_save { + let store = CopilotAuthStore { + version: 3, + accounts: self.accounts.blocking_read().clone(), + default_account_id: self.default_account_id.blocking_read().clone(), + github_token: None, + authenticated_at: None, + }; + if let Ok(content) = serde_json::to_string_pretty(&store) { + let _ = std::fs::write(&self.storage_path, content); + } + } } else if store.github_token.is_some() { // v1 单账号格式,标记待迁移 log::info!("[CopilotAuth] 检测到旧格式,将在首次访问时迁移"); @@ -1296,8 +1736,14 @@ impl CopilotAuthManager { log::warn!("[CopilotAuth] 迁移时验证 Copilot 订阅失败: {e}"); } + // 生成 Machine ID(迁移时也需要,基于账号 ID) + let account_id = user.id.to_string(); + let machine_id = self.generate_machine_id(&account_id); + log::debug!("[CopilotAuth] 为迁移账号 {} 生成 Machine ID", account_id); + // 添加账号 - self.add_account_internal(legacy_token, user).await?; + self.add_account_internal(legacy_token, user, machine_id) + .await?; self.set_migration_error(None).await; log::info!("[CopilotAuth] 旧格式迁移完成"); @@ -1353,6 +1799,27 @@ mod tests { use super::*; use tempfile::tempdir; + fn test_account_data( + token: &str, + login: &str, + id: u64, + avatar_url: Option<&str>, + authenticated_at: i64, + ) -> GitHubAccountData { + GitHubAccountData { + github_token: token.to_string(), + user: GitHubUser { + login: login.to_string(), + id, + avatar_url: avatar_url.map(str::to_string), + }, + authenticated_at, + machine_id: Some(format!("machine-{id}")), + session_id: Some(format!("session-{id}")), + session_refreshed_at: Some(authenticated_at), + } + } + #[test] fn test_copilot_token_expiry() { let now = chrono::Utc::now().timestamp(); @@ -1407,32 +1874,99 @@ mod tests { assert_eq!(parsed.accounts[0].login, "testuser"); } + #[tokio::test] + async fn test_get_account_ids_for_conversation_reuses_session_per_client_session() { + let temp = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp.path().to_path_buf()); + + { + let mut accounts = manager.accounts.write().await; + accounts.insert( + "12345".to_string(), + test_account_data("token", "testuser", 12345, None, 1234567890), + ); + } + { + let mut default_account_id = manager.default_account_id.write().await; + *default_account_id = Some("12345".to_string()); + } + + let (machine_a, session_a1) = manager + .get_account_ids_for_conversation(Some("12345"), Some("conversation-1")) + .await; + let (machine_b, session_a2) = manager + .get_account_ids_for_conversation(Some("12345"), Some("conversation-1")) + .await; + let (_, session_b) = manager + .get_account_ids_for_conversation(Some("12345"), Some("conversation-2")) + .await; + + assert_eq!(machine_a, machine_b); + assert_eq!(session_a1, session_a2); + assert_ne!(session_a1, session_b); + } + + #[tokio::test] + async fn test_conversation_sessions_remain_isolated_after_account_refresh() { + let temp = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp.path().to_path_buf()); + + { + let mut accounts = manager.accounts.write().await; + accounts.insert( + "12345".to_string(), + test_account_data("token", "testuser", 12345, None, 1234567890), + ); + } + + let (_, original_session_a) = manager + .get_account_ids_for_conversation(Some("12345"), Some("conversation-1")) + .await; + let (_, original_session_b) = manager + .get_account_ids_for_conversation(Some("12345"), Some("conversation-2")) + .await; + let refreshed_session = "refreshed-session".to_string(); + let updated = CopilotAuthManager::apply_refreshed_session_id( + "12345", + refreshed_session.clone(), + chrono::Utc::now().timestamp(), + &manager.accounts, + ) + .await; + + let (_, updated_session_a) = manager + .get_account_ids_for_conversation(Some("12345"), Some("conversation-1")) + .await; + let (_, updated_session_b) = manager + .get_account_ids_for_conversation(Some("12345"), Some("conversation-2")) + .await; + let (_, fallback_session) = manager + .get_account_ids_for_conversation(Some("12345"), None) + .await; + + assert!(updated); + assert_eq!(updated_session_a, original_session_a); + assert_eq!(updated_session_b, original_session_b); + assert_ne!(updated_session_a, updated_session_b); + assert_eq!(fallback_session, Some(refreshed_session)); + } + #[test] fn test_multi_account_store_serialization() { let mut accounts = HashMap::new(); accounts.insert( "12345".to_string(), - GitHubAccountData { - github_token: "gho_test_token".to_string(), - user: GitHubUser { - login: "alice".to_string(), - id: 12345, - avatar_url: Some("https://example.com/alice.png".to_string()), - }, - authenticated_at: 1700000000, - }, + test_account_data( + "gho_test_token", + "alice", + 12345, + Some("https://example.com/alice.png"), + 1700000000, + ), ); accounts.insert( "67890".to_string(), - GitHubAccountData { - github_token: "gho_test_token_2".to_string(), - user: GitHubUser { - login: "bob".to_string(), - id: 67890, - avatar_url: None, - }, - authenticated_at: 1700000001, - }, + test_account_data("gho_test_token_2", "bob", 67890, None, 1700000001), ); let store = CopilotAuthStore { @@ -1471,15 +2005,13 @@ mod tests { #[test] fn test_github_account_from_data() { - let data = GitHubAccountData { - github_token: "gho_test".to_string(), - user: GitHubUser { - login: "testuser".to_string(), - id: 99999, - avatar_url: Some("https://example.com/avatar.png".to_string()), - }, - authenticated_at: 1700000000, - }; + let data = test_account_data( + "gho_test", + "testuser", + 99999, + Some("https://example.com/avatar.png"), + 1700000000, + ); let account = GitHubAccount::from(&data); assert_eq!(account.id, "99999"); @@ -1496,27 +2028,11 @@ mod tests { let mut accounts = HashMap::new(); accounts.insert( "12345".to_string(), - GitHubAccountData { - github_token: "gho_test_token".to_string(), - user: GitHubUser { - login: "alice".to_string(), - id: 12345, - avatar_url: None, - }, - authenticated_at: 1700000000, - }, + test_account_data("gho_test_token", "alice", 12345, None, 1700000000), ); accounts.insert( "67890".to_string(), - GitHubAccountData { - github_token: "gho_test_token_2".to_string(), - user: GitHubUser { - login: "bob".to_string(), - id: 67890, - avatar_url: None, - }, - authenticated_at: 1700000001, - }, + test_account_data("gho_test_token_2", "bob", 67890, None, 1700000001), ); assert_eq!( @@ -1525,6 +2041,44 @@ mod tests { ); } + #[test] + fn test_load_from_disk_refreshes_stale_session_id() { + let temp_dir = tempdir().unwrap(); + let storage_path = temp_dir.path().join("copilot_auth.json"); + let stale_refreshed_at = + chrono::Utc::now().timestamp() - CopilotAuthManager::SESSION_REFRESH_BASE_SECS - 10; + let stale_session_id = "stale-session".to_string(); + + let mut accounts = HashMap::new(); + accounts.insert( + "12345".to_string(), + GitHubAccountData { + session_id: Some(stale_session_id.clone()), + session_refreshed_at: Some(stale_refreshed_at), + ..test_account_data("gho_test", "alice", 12345, None, 1700000000) + }, + ); + + let store = CopilotAuthStore { + version: 3, + accounts, + default_account_id: Some("12345".to_string()), + github_token: None, + authenticated_at: None, + }; + std::fs::write(&storage_path, serde_json::to_string_pretty(&store).unwrap()).unwrap(); + + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + let accounts = manager.accounts.blocking_read(); + let restored = accounts.get("12345").unwrap(); + let refreshed_session_id = restored.session_id.as_ref().unwrap(); + let refreshed_at = restored.session_refreshed_at.unwrap(); + + assert_ne!(refreshed_session_id, &stale_session_id); + assert!(refreshed_at > stale_refreshed_at); + } + #[tokio::test] async fn test_get_model_vendor_from_cache() { let temp_dir = tempdir().unwrap(); @@ -1538,15 +2092,7 @@ mod tests { let mut accounts = manager.accounts.write().await; accounts.insert( "12345".to_string(), - GitHubAccountData { - github_token: "gho_test".to_string(), - user: GitHubUser { - login: "alice".to_string(), - id: 12345, - avatar_url: None, - }, - authenticated_at: 1700000000, - }, + test_account_data("gho_test", "alice", 12345, None, 1700000000), ); } { @@ -1622,15 +2168,7 @@ mod tests { let mut accounts = manager.accounts.write().await; accounts.insert( "12345".to_string(), - GitHubAccountData { - github_token: "gho_test".to_string(), - user: GitHubUser { - login: "alice".to_string(), - id: 12345, - avatar_url: None, - }, - authenticated_at: 1700000000, - }, + test_account_data("gho_test", "alice", 12345, None, 1700000000), ); } // 设置 API endpoint 缓存 @@ -1656,15 +2194,7 @@ mod tests { let mut accounts = manager.accounts.write().await; accounts.insert( "12345".to_string(), - GitHubAccountData { - github_token: "gho_test".to_string(), - user: GitHubUser { - login: "alice".to_string(), - id: 12345, - avatar_url: None, - }, - authenticated_at: 1700000000, - }, + test_account_data("gho_test", "alice", 12345, None, 1700000000), ); } // 设置 API endpoint 缓存 @@ -1726,6 +2256,54 @@ mod tests { } } + #[tokio::test] + async fn test_clear_auth_removes_persisted_auth_file() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + { + let mut accounts = manager.accounts.write().await; + accounts.insert( + "12345".to_string(), + test_account_data("gho_test", "alice", 12345, None, 1700000000), + ); + } + manager.save_to_disk().await.unwrap(); + assert!(manager.storage_path.exists()); + + manager.clear_auth().await.unwrap(); + + assert!( + !manager.storage_path.exists(), + "persisted auth file should be removed when clearing all auth" + ); + } + + #[tokio::test] + async fn test_clear_auth_aborts_session_refresh_tasks() { + let temp_dir = tempdir().unwrap(); + let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); + + let task = tokio::spawn(async { + tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await; + }); + let abort_handle = task.abort_handle(); + + { + let mut session_refresh_tasks = manager.session_refresh_tasks.write().await; + session_refresh_tasks.insert("12345".to_string(), task); + } + + manager.clear_auth().await.unwrap(); + tokio::task::yield_now().await; + + { + let session_refresh_tasks = manager.session_refresh_tasks.read().await; + assert!(session_refresh_tasks.is_empty()); + } + assert!(abort_handle.is_finished()); + } + #[tokio::test] async fn test_clear_auth_cleans_memory_even_when_file_removal_fails() { let temp_dir = tempdir().unwrap(); @@ -1738,15 +2316,7 @@ mod tests { let mut accounts = manager.accounts.write().await; accounts.insert( "12345".to_string(), - GitHubAccountData { - github_token: "gho_test".to_string(), - user: GitHubUser { - login: "alice".to_string(), - id: 12345, - avatar_url: None, - }, - authenticated_at: 1700000000, - }, + test_account_data("gho_test", "alice", 12345, None, 1700000000), ); } { @@ -1762,8 +2332,7 @@ mod tests { } let result = manager.clear_auth().await; - // Should still return an error for the file deletion failure - assert!(result.is_err()); + assert!(result.is_ok()); // But memory state should already be cleaned let accounts = manager.accounts.read().await; diff --git a/src-tauri/src/proxy/providers/gemini.rs b/src-tauri/src/proxy/providers/gemini.rs index e0e88feb8..832e61355 100644 --- a/src-tauri/src/proxy/providers/gemini.rs +++ b/src-tauri/src/proxy/providers/gemini.rs @@ -9,6 +9,7 @@ use super::{AuthInfo, AuthStrategy, ProviderAdapter, ProviderType}; use crate::provider::Provider; use crate::proxy::error::ProxyError; +use serde_json::Value; /// Gemini 适配器 pub struct GeminiAdapter; @@ -216,7 +217,11 @@ impl ProviderAdapter for GeminiAdapter { url } - fn get_auth_headers(&self, auth: &AuthInfo) -> Vec<(http::HeaderName, http::HeaderValue)> { + fn get_auth_headers( + &self, + auth: &AuthInfo, + _request_body: Option<&Value>, + ) -> Vec<(http::HeaderName, http::HeaderValue)> { use http::{HeaderName, HeaderValue}; match auth.strategy { AuthStrategy::GoogleOAuth => { diff --git a/src-tauri/src/services/provider/mod.rs b/src-tauri/src/services/provider/mod.rs index fd0a4d03f..95247fb41 100644 --- a/src-tauri/src/services/provider/mod.rs +++ b/src-tauri/src/services/provider/mod.rs @@ -49,6 +49,51 @@ pub struct SwitchResult { pub warnings: Vec, } +fn proxy_required_reason(app_type: &AppType, provider: &Provider) -> Option<&'static str> { + if provider.category.as_deref() == Some("official") { + return None; + } + + let meta = provider.meta.as_ref(); + let is_copilot_provider = matches!( + (app_type, meta.and_then(|m| m.provider_type.as_deref())), + (AppType::Claude, Some("github_copilot")) + ); + + if is_copilot_provider { + return Some("使用 GitHub Copilot 作为 Claude 供应商"); + } + + match ( + app_type, + meta.and_then(|m| m.api_format.as_deref()), + meta.and_then(|m| m.is_full_url), + ) { + (AppType::Claude, Some("openai_chat"), _) => Some("使用 OpenAI Chat 接口格式"), + (AppType::Claude, Some("openai_responses"), _) => Some("使用 OpenAI Responses 接口格式"), + (AppType::Claude | AppType::Codex, _, Some(true)) => Some("开启了完整 URL 连接模式"), + _ => None, + } +} + +fn ensure_proxy_ready_for_switch( + state: &AppState, + app_type: &AppType, + provider: &Provider, +) -> Result<(), AppError> { + let Some(reason) = proxy_required_reason(app_type, provider) else { + return Ok(()); + }; + + if futures::executor::block_on(state.proxy_service.is_running()) { + return Ok(()); + } + + Err(AppError::Message(format!( + "此供应商{reason},需要代理服务才能正常使用,请先启动代理" + ))) +} + #[cfg(test)] mod tests { use super::*; @@ -1379,6 +1424,8 @@ impl ProviderService { .get(id) .ok_or_else(|| AppError::Message(format!("供应商 {id} 不存在")))?; + ensure_proxy_ready_for_switch(state, &app_type, _provider)?; + // OMO providers are switched through their own exclusive path. if matches!(app_type, AppType::OpenCode) && _provider.category.as_deref() == Some("omo") { return Self::switch_normal(state, app_type, id, &providers); diff --git a/src-tauri/src/services/stream_check.rs b/src-tauri/src/services/stream_check.rs index 5e2203a05..b062df08a 100644 --- a/src-tauri/src/services/stream_check.rs +++ b/src-tauri/src/services/stream_check.rs @@ -404,8 +404,8 @@ impl StreamCheckService { .header("x-initiator", "user") .header("x-interaction-type", "conversation-agent") .header("x-vscode-user-agent-library-version", "electron-fetch") - .header("x-request-id", &request_id) - .header("x-agent-task-id", &request_id); + .header("x-request-id", request_id.as_str()) + .header("x-agent-task-id", request_id.as_str()); } else if is_openai_chat || is_openai_responses { // OpenAI-compatible targets: Bearer auth + SSE headers only request_builder = request_builder diff --git a/src-tauri/tests/provider_service.rs b/src-tauri/tests/provider_service.rs index c4400c230..df3fe8f64 100644 --- a/src-tauri/tests/provider_service.rs +++ b/src-tauri/tests/provider_service.rs @@ -238,6 +238,68 @@ command = "say" ); } +#[test] +fn provider_service_switch_claude_copilot_requires_running_proxy() { + let _guard = test_mutex().lock().expect("acquire test mutex"); + reset_test_fs(); + let _home = ensure_test_home(); + + let mut initial_config = MultiAppConfig::default(); + { + let manager = initial_config + .get_manager_mut(&AppType::Claude) + .expect("claude manager"); + manager.providers.insert( + "copilot-provider".to_string(), + Provider { + id: "copilot-provider".to_string(), + name: "Copilot".to_string(), + settings_config: json!({ + "env": { + "ANTHROPIC_BASE_URL": "https://api.githubcopilot.com" + } + }), + website_url: None, + category: Some("custom".to_string()), + created_at: None, + sort_index: None, + notes: None, + meta: Some(ProviderMeta { + provider_type: Some("github_copilot".to_string()), + ..Default::default() + }), + icon: None, + icon_color: None, + in_failover_queue: false, + }, + ); + } + + let state = create_test_state_with_config(&initial_config).expect("create test state"); + + let err = ProviderService::switch(&state, AppType::Claude, "copilot-provider") + .expect_err("switch should be blocked when proxy is not running"); + + assert!( + err.to_string().contains("需要代理服务才能正常使用"), + "error should explain proxy requirement, got: {err}" + ); + + let current_id = state + .db + .get_current_provider(AppType::Claude.as_str()) + .expect("get current provider"); + assert!( + current_id.is_none(), + "current provider should remain unchanged when switch is blocked" + ); + + assert!( + !get_claude_settings_path().exists(), + "Claude live config should not be written when switch is blocked" + ); +} + #[test] fn sync_current_provider_for_app_keeps_live_takeover_and_updates_restore_backup() { let _guard = test_mutex().lock().expect("acquire test mutex"); diff --git a/src/lib/query/mutations.ts b/src/lib/query/mutations.ts index e5ab2815c..520632ed9 100644 --- a/src/lib/query/mutations.ts +++ b/src/lib/query/mutations.ts @@ -9,6 +9,15 @@ import { extractErrorMessage } from "@/utils/errorUtils"; import { generateUUID } from "@/utils/uuid"; import { openclawKeys } from "@/hooks/useOpenClaw"; +function isProxyRequirementSwitchError(detail: string): boolean { + return [ + "需要代理服务", + "proxy service", + "プロキシサービス", + "Start the proxy first", + ].some((needle) => detail.includes(needle)); +} + export const useAddProviderMutation = (appId: AppId) => { const queryClient = useQueryClient(); const { t } = useTranslation(); @@ -256,7 +265,9 @@ export const useSwitchProviderMutation = (appId: AppId) => { }, onError: (error: Error) => { const detail = extractErrorMessage(error) || t("common.unknown"); - + if (isProxyRequirementSwitchError(detail)) { + return; + } toast.error( t("notifications.switchFailedTitle", { defaultValue: "切换失败" }), {