Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src-tauri/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ indexmap = { version = "2", features = ["serde"] }
rust_decimal = "1.33"
uuid = { version = "1.11", features = ["v4"] }
sha2 = "0.10"
mac_address = "1.1"
json5 = "0.4"
json-five = "0.3.1"

Expand Down
13 changes: 11 additions & 2 deletions src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,17 @@ pub fn run() {
use tokio::sync::RwLock;

let app_config_dir = crate::config::get_app_config_dir();
let copilot_auth_manager = CopilotAuthManager::new(app_config_dir);
app.manage(CopilotAuthState(Arc::new(RwLock::new(copilot_auth_manager))));
let copilot_auth_manager = Arc::new(RwLock::new(CopilotAuthManager::new(
app_config_dir,
)));
app.manage(CopilotAuthState(Arc::clone(&copilot_auth_manager)));
tauri::async_runtime::spawn(async move {
copilot_auth_manager
.read()
.await
.initialize_background_tasks()
.await;
});
log::info!("✓ CopilotAuthManager initialized");
}

Expand Down
184 changes: 134 additions & 50 deletions src-tauri/src/proxy/forwarder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ pub struct ForwardError {
pub provider: Option<Provider>,
}

pub struct ForwardRequestInput {
pub endpoint: String,
pub body: Value,
pub headers: axum::http::HeaderMap,
pub extensions: Extensions,
pub client_session_id: Option<String>,
}

struct ForwardMetadata<'a> {
headers: &'a axum::http::HeaderMap,
extensions: &'a Extensions,
client_session_id: Option<&'a str>,
}

pub struct RequestForwarder {
/// 共享的 ProviderRouter(持有熔断器状态)
router: Arc<ProviderRouter>,
Expand Down Expand Up @@ -99,12 +113,16 @@ impl RequestForwarder {
pub async fn forward_with_retry(
&self,
app_type: &AppType,
endpoint: &str,
body: Value,
headers: axum::http::HeaderMap,
extensions: Extensions,
request: ForwardRequestInput,
providers: Vec<Provider>,
) -> Result<ForwardResult, ForwardError> {
let ForwardRequestInput {
endpoint,
body,
headers,
extensions,
client_session_id,
} = request;
// 获取适配器
let adapter = get_adapter(app_type);
let app_type_str = app_type.as_str();
Expand Down Expand Up @@ -176,10 +194,13 @@ impl RequestForwarder {
match self
.forward(
provider,
endpoint,
&endpoint,
&provider_body,
&headers,
&extensions,
ForwardMetadata {
headers: &headers,
extensions: &extensions,
client_session_id: client_session_id.as_deref(),
},
adapter.as_ref(),
)
.await
Expand Down Expand Up @@ -306,10 +327,13 @@ impl RequestForwarder {
match self
.forward(
provider,
endpoint,
&endpoint,
&provider_body,
&headers,
&extensions,
ForwardMetadata {
headers: &headers,
extensions: &extensions,
client_session_id: client_session_id.as_deref(),
},
adapter.as_ref(),
)
.await
Expand Down Expand Up @@ -505,10 +529,13 @@ impl RequestForwarder {
match self
.forward(
provider,
endpoint,
&endpoint,
&provider_body,
&headers,
&extensions,
ForwardMetadata {
headers: &headers,
extensions: &extensions,
client_session_id: client_session_id.as_deref(),
},
adapter.as_ref(),
)
.await
Expand Down Expand Up @@ -746,14 +773,16 @@ impl RequestForwarder {
provider: &Provider,
endpoint: &str,
body: &Value,
headers: &axum::http::HeaderMap,
extensions: &Extensions,
metadata: ForwardMetadata<'_>,
adapter: &dyn ProviderAdapter,
) -> Result<(ProxyResponse, Option<String>), ProxyError> {
let headers = metadata.headers;
let extensions = metadata.extensions;
let client_session_id = metadata.client_session_id;
// 使用适配器提取 base_url
let mut base_url = adapter.extract_base_url(provider)?;

let is_full_url = provider
let has_explicit_full_url = provider
.meta
.as_ref()
.and_then(|meta| meta.is_full_url)
Expand Down Expand Up @@ -830,36 +859,6 @@ impl RequestForwarder {
} else {
None
};

// GitHub Copilot 动态 endpoint 路由
// 从 CopilotAuthManager 获取缓存的 API endpoint(支持企业版等非默认 endpoint)
if is_copilot && !is_full_url {
if let Some(app_handle) = &self.app_handle {
let copilot_state = app_handle.state::<CopilotAuthState>();
let copilot_auth = copilot_state.0.read().await;

// 从 provider.meta 获取关联的 GitHub 账号 ID
let account_id = provider
.meta
.as_ref()
.and_then(|m| m.managed_account_id_for("github_copilot"));

let dynamic_endpoint = match &account_id {
Some(id) => copilot_auth.get_api_endpoint(id).await,
None => copilot_auth.get_default_api_endpoint().await,
};

// 只在动态 endpoint 与当前 base_url 不同时替换
if dynamic_endpoint != base_url {
log::debug!(
"[Copilot] 使用动态 API endpoint: {} (原: {})",
dynamic_endpoint,
base_url
);
base_url = dynamic_endpoint;
}
}
}
let resolved_claude_api_format = if adapter.name() == "Claude" {
Some(
self.resolve_claude_api_format(provider, &mapped_body, is_copilot)
Expand Down Expand Up @@ -887,6 +886,34 @@ impl RequestForwarder {
)
};

let is_full_url = has_explicit_full_url
|| is_legacy_full_url_for_endpoint(&base_url, &effective_endpoint);
// GitHub Copilot 动态 endpoint 路由
// 从 CopilotAuthManager 获取缓存的 API endpoint(支持企业版等非默认 endpoint)
if is_copilot && !is_full_url {
if let Some(app_handle) = &self.app_handle {
let copilot_state = app_handle.state::<CopilotAuthState>();
let copilot_auth = copilot_state.0.read().await;

// 从 provider.meta 获取关联的 GitHub 账号 ID
let account_id = provider
.meta
.as_ref()
.and_then(|m| m.managed_account_id_for("github_copilot"));

let dynamic_endpoint = match &account_id {
Some(id) => copilot_auth.get_api_endpoint(id).await,
None => copilot_auth.get_default_api_endpoint().await,
};

// 只在动态 endpoint 与当前 base_url 不同时替换
if dynamic_endpoint != base_url {
log::debug!("[Copilot] 已启用动态 API endpoint");
base_url = dynamic_endpoint;
}
}
}

let url = if is_full_url {
append_query_to_full_url(&base_url, passthrough_query.as_deref())
} else {
Expand Down Expand Up @@ -946,10 +973,27 @@ impl RequestForwarder {

match token_result {
Ok(token) => {
auth = AuthInfo::new(token, AuthStrategy::GitHubCopilot);
// 获取 machine ID 和 session ID
let (machine_id, session_id) = copilot_auth
.get_account_ids_for_conversation(
account_id.as_deref(),
client_session_id,
)
.await;

// 使用新的构造函数
auth = AuthInfo::new_with_ids(
token,
AuthStrategy::GitHubCopilot,
machine_id.clone(),
session_id.clone(),
);

log::debug!(
"[Copilot] 成功获取 Copilot token (account={})",
account_id.as_deref().unwrap_or("default")
"[Copilot] 成功获取 Copilot token (account={}, has_machine_id={}, has_session_id={})",
account_id.as_deref().unwrap_or("default"),
machine_id.is_some(),
session_id.is_some(),
);
}
Err(e) => {
Expand All @@ -969,7 +1013,7 @@ impl RequestForwarder {
));
}
}
adapter.get_auth_headers(&auth)
adapter.get_auth_headers(&auth, Some(&filtered_body))
} else {
Vec::new()
};
Expand Down Expand Up @@ -1008,6 +1052,9 @@ impl RequestForwarder {
"x-vscode-user-agent-library-version",
"x-request-id",
"x-agent-task-id",
// Machine ID 和 Session ID
"vscode-machineid",
"vscode-sessionid",
]
} else {
&[]
Expand Down Expand Up @@ -1573,6 +1620,23 @@ fn append_query_to_full_url(base_url: &str, query: Option<&str>) -> String {
}
}

fn is_legacy_full_url_for_endpoint(base_url: &str, endpoint: &str) -> bool {
let Ok(parsed) = url::Url::parse(base_url) else {
return false;
};

let path = parsed.path().trim_end_matches('/');
let endpoint_path = split_endpoint_and_query(endpoint).0.trim_end_matches('/');

if path.is_empty() || path == "/" || endpoint_path.is_empty() || endpoint_path == "/" {
return false;
}

let normalized_endpoint_path = format!("/{}", endpoint_path.trim_start_matches('/'));

path == endpoint_path || path.ends_with(&normalized_endpoint_path)
}

fn should_force_identity_encoding(
endpoint: &str,
body: &Value,
Expand Down Expand Up @@ -1881,4 +1945,24 @@ mod tests {
assert_eq!(will_replace, should_replace, "{desc}");
}
}

#[test]
fn legacy_full_url_detection_matches_endpoint_path() {
assert!(is_legacy_full_url_for_endpoint(
"https://proxy.example.com/chat/completions",
"/chat/completions"
));
assert!(is_legacy_full_url_for_endpoint(
"https://proxy.example.com/api/v1/messages",
"/v1/messages"
));
assert!(!is_legacy_full_url_for_endpoint(
"https://proxy.example.com/api",
"/chat/completions"
));
assert!(!is_legacy_full_url_for_endpoint(
"https://proxy.example.com/v1",
"/v1/messages"
));
}
}
6 changes: 6 additions & 0 deletions src-tauri/src/proxy/handler_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub struct RequestContext {
pub app_type: AppType,
/// Session ID(从客户端请求提取或新生成)
pub session_id: String,
/// 客户端显式提供的 Session ID(仅当请求携带对话连续性线索时存在)
pub client_session_id: Option<String>,
/// 整流器配置
pub rectifier_config: RectifierConfig,
/// 优化器配置
Expand Down Expand Up @@ -113,6 +115,9 @@ impl RequestContext {
// 提取 Session ID
let session_result = extract_session_id(headers, body, app_type_str);
let session_id = session_result.session_id.clone();
let client_session_id = session_result
.client_provided
.then(|| session_result.session_id.clone());

log::debug!(
"[{}] Session ID: {} (from {:?}, client_provided: {})",
Expand Down Expand Up @@ -161,6 +166,7 @@ impl RequestContext {
app_type_str,
app_type,
session_id,
client_session_id,
rectifier_config,
optimizer_config,
copilot_optimizer_config,
Expand Down
Loading
Loading