Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src-tauri/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ indexmap = { version = "2", features = ["serde"] }
rust_decimal = "1.33"
uuid = { version = "1.11", features = ["v4"] }
sha2 = "0.10"
mac_address = "1.1"
rand = "0.8"
Copy link

Copilot AI Apr 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

新增了 rand 依赖,但当前仓库内未找到 rand:: 的引用(PR 中的 jitter 也已改用 SystemTime 生成)。如果没有其它未展示的用途,建议移除该依赖,避免增加构建体积和依赖面。

Suggested change
rand = "0.8"

Copilot uses AI. Check for mistakes.
json5 = "0.4"
json-five = "0.3.1"

Expand Down
13 changes: 11 additions & 2 deletions src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,17 @@ pub fn run() {
use tokio::sync::RwLock;

let app_config_dir = crate::config::get_app_config_dir();
let copilot_auth_manager = CopilotAuthManager::new(app_config_dir);
app.manage(CopilotAuthState(Arc::new(RwLock::new(copilot_auth_manager))));
let copilot_auth_manager = Arc::new(RwLock::new(CopilotAuthManager::new(
app_config_dir,
)));
app.manage(CopilotAuthState(Arc::clone(&copilot_auth_manager)));
tauri::async_runtime::spawn(async move {
copilot_auth_manager
.read()
.await
.initialize_background_tasks()
.await;
});
log::info!("✓ CopilotAuthManager initialized");
}

Expand Down
180 changes: 134 additions & 46 deletions src-tauri/src/proxy/forwarder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ pub struct ForwardError {
pub provider: Option<Provider>,
}

pub struct ForwardRequestInput {
pub endpoint: String,
pub body: Value,
pub headers: axum::http::HeaderMap,
pub extensions: Extensions,
pub client_session_id: Option<String>,
}

struct ForwardMetadata<'a> {
headers: &'a axum::http::HeaderMap,
extensions: &'a Extensions,
client_session_id: Option<&'a str>,
}

pub struct RequestForwarder {
/// 共享的 ProviderRouter(持有熔断器状态)
router: Arc<ProviderRouter>,
Expand Down Expand Up @@ -99,12 +113,16 @@ impl RequestForwarder {
pub async fn forward_with_retry(
&self,
app_type: &AppType,
endpoint: &str,
body: Value,
headers: axum::http::HeaderMap,
extensions: Extensions,
request: ForwardRequestInput,
providers: Vec<Provider>,
) -> Result<ForwardResult, ForwardError> {
let ForwardRequestInput {
endpoint,
body,
headers,
extensions,
client_session_id,
} = request;
// 获取适配器
let adapter = get_adapter(app_type);
let app_type_str = app_type.as_str();
Expand Down Expand Up @@ -176,10 +194,13 @@ impl RequestForwarder {
match self
.forward(
provider,
endpoint,
&endpoint,
&provider_body,
&headers,
&extensions,
ForwardMetadata {
headers: &headers,
extensions: &extensions,
client_session_id: client_session_id.as_deref(),
},
adapter.as_ref(),
)
.await
Expand Down Expand Up @@ -306,10 +327,13 @@ impl RequestForwarder {
match self
.forward(
provider,
endpoint,
&endpoint,
&provider_body,
&headers,
&extensions,
ForwardMetadata {
headers: &headers,
extensions: &extensions,
client_session_id: client_session_id.as_deref(),
},
adapter.as_ref(),
)
.await
Expand Down Expand Up @@ -505,10 +529,13 @@ impl RequestForwarder {
match self
.forward(
provider,
endpoint,
&endpoint,
&provider_body,
&headers,
&extensions,
ForwardMetadata {
headers: &headers,
extensions: &extensions,
client_session_id: client_session_id.as_deref(),
},
adapter.as_ref(),
)
.await
Expand Down Expand Up @@ -746,14 +773,16 @@ impl RequestForwarder {
provider: &Provider,
endpoint: &str,
body: &Value,
headers: &axum::http::HeaderMap,
extensions: &Extensions,
metadata: ForwardMetadata<'_>,
adapter: &dyn ProviderAdapter,
) -> Result<(ProxyResponse, Option<String>), ProxyError> {
let headers = metadata.headers;
let extensions = metadata.extensions;
let client_session_id = metadata.client_session_id;
// 使用适配器提取 base_url
let mut base_url = adapter.extract_base_url(provider)?;

let is_full_url = provider
let has_explicit_full_url = provider
.meta
.as_ref()
.and_then(|meta| meta.is_full_url)
Expand Down Expand Up @@ -830,7 +859,35 @@ impl RequestForwarder {
} else {
None
};
let resolved_claude_api_format = if adapter.name() == "Claude" {
Some(
self.resolve_claude_api_format(provider, &mapped_body, is_copilot)
.await,
)
} else {
None
};
let needs_transform = match resolved_claude_api_format.as_deref() {
Some(api_format) => super::providers::claude_api_format_needs_transform(api_format),
None => adapter.needs_transform(provider),
};
let (effective_endpoint, passthrough_query) =
if needs_transform && adapter.name() == "Claude" {
let api_format = resolved_claude_api_format
.as_deref()
.unwrap_or_else(|| super::providers::get_claude_api_format(provider));
rewrite_claude_transform_endpoint(endpoint, api_format, is_copilot)
} else {
(
endpoint.to_string(),
split_endpoint_and_query(endpoint)
.1
.map(ToString::to_string),
)
};

let is_full_url = has_explicit_full_url
|| is_legacy_full_url_for_endpoint(&base_url, &effective_endpoint);
// GitHub Copilot 动态 endpoint 路由
// 从 CopilotAuthManager 获取缓存的 API endpoint(支持企业版等非默认 endpoint)
if is_copilot && !is_full_url {
Expand Down Expand Up @@ -860,32 +917,6 @@ impl RequestForwarder {
}
}
}
let resolved_claude_api_format = if adapter.name() == "Claude" {
Some(
self.resolve_claude_api_format(provider, &mapped_body, is_copilot)
.await,
)
} else {
None
};
let needs_transform = match resolved_claude_api_format.as_deref() {
Some(api_format) => super::providers::claude_api_format_needs_transform(api_format),
None => adapter.needs_transform(provider),
};
let (effective_endpoint, passthrough_query) =
if needs_transform && adapter.name() == "Claude" {
let api_format = resolved_claude_api_format
.as_deref()
.unwrap_or_else(|| super::providers::get_claude_api_format(provider));
rewrite_claude_transform_endpoint(endpoint, api_format, is_copilot)
} else {
(
endpoint.to_string(),
split_endpoint_and_query(endpoint)
.1
.map(ToString::to_string),
)
};

let url = if is_full_url {
append_query_to_full_url(&base_url, passthrough_query.as_deref())
Expand Down Expand Up @@ -946,10 +977,27 @@ impl RequestForwarder {

match token_result {
Ok(token) => {
auth = AuthInfo::new(token, AuthStrategy::GitHubCopilot);
// 获取 machine ID 和 session ID
let (machine_id, session_id) = copilot_auth
.get_account_ids_for_conversation(
account_id.as_deref(),
client_session_id,
)
.await;

// 使用新的构造函数
auth = AuthInfo::new_with_ids(
token,
AuthStrategy::GitHubCopilot,
machine_id.clone(),
session_id.clone(),
);

log::debug!(
"[Copilot] 成功获取 Copilot token (account={})",
account_id.as_deref().unwrap_or("default")
"[Copilot] 成功获取 Copilot token (account={}, machine_id={}, session_id={})",
account_id.as_deref().unwrap_or("default"),
machine_id.as_ref().map(|s| &s[..16]).unwrap_or("none"),
session_id.as_ref().map(|s| &s[..16]).unwrap_or("none"),
Copy link

Copilot AI Apr 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

日志里输出 machine_id/session_id(即使截断到 16 位)仍是稳定标识符,可能在 debug 日志收集/上传时造成隐私泄露风险。建议改为仅记录是否存在(例如 has_machine_id/has_session_id),或进一步做不可逆脱敏(例如固定长度的 hash 前缀)并确保默认日志级别不会包含这些字段。

Suggested change
"[Copilot] 成功获取 Copilot token (account={}, machine_id={}, session_id={})",
account_id.as_deref().unwrap_or("default"),
machine_id.as_ref().map(|s| &s[..16]).unwrap_or("none"),
session_id.as_ref().map(|s| &s[..16]).unwrap_or("none"),
"[Copilot] 成功获取 Copilot token (account={}, has_machine_id={}, has_session_id={})",
account_id.as_deref().unwrap_or("default"),
machine_id.is_some(),
session_id.is_some(),

Copilot uses AI. Check for mistakes.
);
}
Err(e) => {
Expand All @@ -969,7 +1017,7 @@ impl RequestForwarder {
));
}
}
adapter.get_auth_headers(&auth)
adapter.get_auth_headers(&auth, Some(&filtered_body))
} else {
Vec::new()
};
Expand Down Expand Up @@ -1008,6 +1056,9 @@ impl RequestForwarder {
"x-vscode-user-agent-library-version",
"x-request-id",
"x-agent-task-id",
// Machine ID 和 Session ID
"vscode-machineid",
"vscode-sessionid",
]
} else {
&[]
Expand Down Expand Up @@ -1573,6 +1624,23 @@ fn append_query_to_full_url(base_url: &str, query: Option<&str>) -> String {
}
}

fn is_legacy_full_url_for_endpoint(base_url: &str, endpoint: &str) -> bool {
let Ok(parsed) = url::Url::parse(base_url) else {
return false;
};

let path = parsed.path().trim_end_matches('/');
let endpoint_path = split_endpoint_and_query(endpoint).0.trim_end_matches('/');

if path.is_empty() || path == "/" || endpoint_path.is_empty() || endpoint_path == "/" {
return false;
}

let normalized_endpoint_path = format!("/{}", endpoint_path.trim_start_matches('/'));

path == endpoint_path || path.ends_with(&normalized_endpoint_path)
}

fn should_force_identity_encoding(
endpoint: &str,
body: &Value,
Expand Down Expand Up @@ -1881,4 +1949,24 @@ mod tests {
assert_eq!(will_replace, should_replace, "{desc}");
}
}

#[test]
fn legacy_full_url_detection_matches_endpoint_path() {
assert!(is_legacy_full_url_for_endpoint(
"https://proxy.example.com/chat/completions",
"/chat/completions"
));
assert!(is_legacy_full_url_for_endpoint(
"https://proxy.example.com/api/v1/messages",
"/v1/messages"
));
assert!(!is_legacy_full_url_for_endpoint(
"https://proxy.example.com/api",
"/chat/completions"
));
assert!(!is_legacy_full_url_for_endpoint(
"https://proxy.example.com/v1",
"/v1/messages"
));
}
}
6 changes: 6 additions & 0 deletions src-tauri/src/proxy/handler_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub struct RequestContext {
pub app_type: AppType,
/// Session ID(从客户端请求提取或新生成)
pub session_id: String,
/// 客户端显式提供的 Session ID(仅当请求携带对话连续性线索时存在)
pub client_session_id: Option<String>,
/// 整流器配置
pub rectifier_config: RectifierConfig,
/// 优化器配置
Expand Down Expand Up @@ -113,6 +115,9 @@ impl RequestContext {
// 提取 Session ID
let session_result = extract_session_id(headers, body, app_type_str);
let session_id = session_result.session_id.clone();
let client_session_id = session_result
.client_provided
.then(|| session_result.session_id.clone());

log::debug!(
"[{}] Session ID: {} (from {:?}, client_provided: {})",
Expand Down Expand Up @@ -161,6 +166,7 @@ impl RequestContext {
app_type_str,
app_type,
session_id,
client_session_id,
rectifier_config,
optimizer_config,
copilot_optimizer_config,
Expand Down
Loading
Loading