diff --git a/src-tauri/src/provider.rs b/src-tauri/src/provider.rs index 67616853c..d8a4b6f14 100644 --- a/src-tauri/src/provider.rs +++ b/src-tauri/src/provider.rs @@ -299,6 +299,54 @@ pub struct ProviderMeta { /// 用于多账号支持,关联到特定的 GitHub 账号 #[serde(rename = "githubAccountId", skip_serializing_if = "Option::is_none")] pub github_account_id: Option, + /// 模型路由配置(支持单个 Provider 根据请求模型动态切换 API 格式和目标端点) + #[serde(rename = "modelRoutingConfig", skip_serializing_if = "Option::is_none")] + pub model_routing_config: Option, +} + +/// 模型路由配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelRoutingConfig { + /// 是否启用路由配置 + pub enabled: bool, + /// 路由规则列表 + #[serde(default)] + pub routes: Vec, + /// 兜底配置(当没有匹配的路由时使用) + #[serde(skip_serializing_if = "Option::is_none")] + pub fallback: Option, +} + +/// 单个模型路由规则 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelRoute { + /// 源模型名称(不区分大小写匹配) + #[serde(rename = "sourceModel")] + pub source_model: String, + /// 路由目标配置 + pub target: RouteTarget, +} + +/// 路由目标配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouteTarget { + /// 目标 Base URL + #[serde(rename = "baseUrl")] + pub base_url: String, + /// 目标 API 格式 + #[serde(rename = "apiFormat")] + pub api_format: String, + /// 目标模型名称 + #[serde(rename = "modelName")] + pub model_name: String, +} + +/// 路由兜底配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouteFallback { + /// 兜底 API 格式 + #[serde(rename = "apiFormat")] + pub api_format: String, } impl ProviderMeta { diff --git a/src-tauri/src/proxy/forwarder.rs b/src-tauri/src/proxy/forwarder.rs index c4172ea3d..5f9fc1cec 100644 --- a/src-tauri/src/proxy/forwarder.rs +++ b/src-tauri/src/proxy/forwarder.rs @@ -760,6 +760,9 @@ impl RequestForwarder { .and_then(|meta| meta.is_full_url) .unwrap_or(false); + // 提取原始模型名称(在模型映射之前,用于模型路由) + let original_model = body.get("model").and_then(|m| m.as_str()); + // 应用模型映射(独立于格式转换) let (mapped_body, _original_model, _mapped_model) = super::model_mapper::apply_model_mapping(body.clone(), provider); @@ -861,9 +864,20 @@ impl RequestForwarder { } } } + + // Claude 模型路由:根据请求模型动态修改 base_url + // 只在非 Copilot 且非 full_url 模式下处理 + if adapter.name() == "Claude" && !is_copilot && !is_full_url { + base_url = crate::proxy::providers::get_routed_base_url( + provider, + original_model, + &base_url, + ); + } + let resolved_claude_api_format = if adapter.name() == "Claude" { Some( - self.resolve_claude_api_format(provider, &mapped_body, is_copilot) + self.resolve_claude_api_format(provider, &mapped_body, is_copilot, original_model) .await, ) } else { @@ -1362,9 +1376,12 @@ impl RequestForwarder { provider: &Provider, body: &Value, is_copilot: bool, + original_model: Option<&str>, ) -> String { if !is_copilot { - return super::providers::get_claude_api_format(provider).to_string(); + // 使用支持模型路由的 get_claude_api_format_with_model + return crate::proxy::providers::get_claude_api_format_with_model(provider, original_model) + .to_string(); } let model = body.get("model").and_then(|value| value.as_str()); diff --git a/src-tauri/src/proxy/providers/claude.rs b/src-tauri/src/proxy/providers/claude.rs index 52689d3b5..06cc65af6 100644 --- a/src-tauri/src/proxy/providers/claude.rs +++ b/src-tauri/src/proxy/providers/claude.rs @@ -20,8 +20,17 @@ use crate::proxy::error::ProxyError; /// 获取 Claude 供应商的 API 格式 /// /// 供 handler/forwarder 外部使用的公开函数。 -/// 优先级:meta.apiFormat > settings_config.api_format > openrouter_compat_mode > 默认 "anthropic" +/// 优先级:model_routing_config > meta.apiFormat > settings_config.api_format > openrouter_compat_mode > 默认 "anthropic" pub fn get_claude_api_format(provider: &Provider) -> &'static str { + get_claude_api_format_with_model(provider, None) +} + +/// 获取 Claude 供应商的 API 格式(支持模型路由) +/// +/// # Arguments +/// * `provider` - Provider 实例 +/// * `model_name` - 请求中的模型名称(可选,用于模型路由匹配) +pub fn get_claude_api_format_with_model(provider: &Provider, model_name: Option<&str>) -> &'static str { // 0) Codex OAuth 强制使用 openai_responses(不可被覆盖) if let Some(meta) = provider.meta.as_ref() { if meta.provider_type.as_deref() == Some("codex_oauth") { @@ -29,7 +38,27 @@ pub fn get_claude_api_format(provider: &Provider) -> &'static str { } } - // 1) Preferred: meta.apiFormat (SSOT, never written to Claude Code config) + // 1) 检查 model_routing_config(支持根据请求模型动态选择 API 格式) + if let Some(meta) = provider.meta.as_ref() { + if let Some(ref routing_config) = meta.model_routing_config { + if routing_config.enabled { + if let Some(model) = model_name { + // 大小写不敏感匹配 + if let Some(route) = routing_config.routes.iter().find(|r| { + r.source_model.eq_ignore_ascii_case(model) + }) { + return Box::leak(route.target.api_format.clone().into_boxed_str()); + } + } + // 匹配不到时使用 fallback + if let Some(ref fallback) = routing_config.fallback { + return Box::leak(fallback.api_format.clone().into_boxed_str()); + } + } + } + } + + // 2) Preferred: meta.apiFormat (SSOT, never written to Claude Code config) if let Some(meta) = provider.meta.as_ref() { if let Some(api_format) = meta.api_format.as_deref() { return match api_format { @@ -40,7 +69,7 @@ pub fn get_claude_api_format(provider: &Provider) -> &'static str { } } - // 2) Backward compatibility: legacy settings_config.api_format + // 3) Backward compatibility: legacy settings_config.api_format if let Some(api_format) = provider .settings_config .get("api_format") @@ -53,7 +82,7 @@ pub fn get_claude_api_format(provider: &Provider) -> &'static str { }; } - // 3) Backward compatibility: legacy openrouter_compat_mode (bool/number/string) + // 4) Backward compatibility: legacy openrouter_compat_mode (bool/number/string) let raw = provider.settings_config.get("openrouter_compat_mode"); let enabled = match raw { Some(serde_json::Value::Bool(v)) => *v, @@ -72,6 +101,32 @@ pub fn get_claude_api_format(provider: &Provider) -> &'static str { } } +/// 根据模型路由获取目标 Base URL +/// +/// # Arguments +/// * `provider` - Provider 实例 +/// * `model_name` - 请求中的模型名称(用于路由匹配) +/// * `default_url` - 默认的 Base URL(Provider 配置的原始 URL) +pub fn get_routed_base_url(provider: &Provider, model_name: Option<&str>, default_url: &str) -> String { + // 检查 model_routing_config + if let Some(meta) = provider.meta.as_ref() { + if let Some(ref routing_config) = meta.model_routing_config { + if routing_config.enabled { + if let Some(model) = model_name { + // 大小写不敏感匹配 + if let Some(route) = routing_config.routes.iter().find(|r| { + r.source_model.eq_ignore_ascii_case(model) + }) { + return route.target.base_url.clone(); + } + } + } + } + } + // 返回默认 URL + default_url.to_string() +} + pub fn claude_api_format_needs_transform(api_format: &str) -> bool { matches!(api_format, "openai_chat" | "openai_responses") } diff --git a/src-tauri/src/proxy/providers/mod.rs b/src-tauri/src/proxy/providers/mod.rs index b6b8bb643..65bba02dc 100644 --- a/src-tauri/src/proxy/providers/mod.rs +++ b/src-tauri/src/proxy/providers/mod.rs @@ -33,6 +33,7 @@ pub use adapter::ProviderAdapter; pub use auth::{AuthInfo, AuthStrategy}; pub use claude::{ claude_api_format_needs_transform, get_claude_api_format, + get_claude_api_format_with_model, get_routed_base_url, transform_claude_request_for_api_format, ClaudeAdapter, }; pub use codex::CodexAdapter; diff --git a/src/components/providers/forms/ModelRoutingConfigPanel.tsx b/src/components/providers/forms/ModelRoutingConfigPanel.tsx new file mode 100644 index 000000000..020bc34ae --- /dev/null +++ b/src/components/providers/forms/ModelRoutingConfigPanel.tsx @@ -0,0 +1,430 @@ +import { useTranslation } from "react-i18next"; +import { useState, useEffect } from "react"; +import { + ChevronDown, + ChevronRight, + Route, + Plus, + Trash2, + Copy, + ArrowRight, +} from "lucide-react"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Switch } from "@/components/ui/switch"; +import { Button } from "@/components/ui/button"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { cn } from "@/lib/utils"; +import type { + ModelRoutingConfig, + ModelRoute, + RouteTarget, + RouteFallback, +} from "@/types"; + +interface ModelRoutingConfigPanelProps { + config: ModelRoutingConfig; + onConfigChange: (config: ModelRoutingConfig) => void; +} + +export function ModelRoutingConfigPanel({ + config, + onConfigChange, +}: ModelRoutingConfigPanelProps) { + const { t } = useTranslation(); + const [isOpen, setIsOpen] = useState(config.enabled); + + useEffect(() => { + setIsOpen(config.enabled); + }, [config.enabled]); + + const handleEnabledChange = (enabled: boolean) => { + onConfigChange({ ...config, enabled }); + if (enabled) setIsOpen(true); + }; + + const handleAddRoute = () => { + const newRoute: ModelRoute = { + sourceModel: "", + target: { + baseUrl: "", + apiFormat: "anthropic", + modelName: "", + }, + }; + onConfigChange({ + ...config, + routes: [...config.routes, newRoute], + }); + }; + + const handleUpdateRoute = (index: number, route: ModelRoute) => { + const newRoutes = [...config.routes]; + newRoutes[index] = route; + onConfigChange({ ...config, routes: newRoutes }); + }; + + const handleDeleteRoute = (index: number) => { + onConfigChange({ + ...config, + routes: config.routes.filter((_, i) => i !== index), + }); + }; + + const handleDuplicateRoute = (index: number) => { + const routeToDuplicate = config.routes[index]; + const newRoutes = [...config.routes]; + newRoutes.splice(index + 1, 0, { ...routeToDuplicate }); + onConfigChange({ ...config, routes: newRoutes }); + }; + + const handleFallbackChange = (apiFormat: string) => { + onConfigChange({ + ...config, + fallback: { apiFormat: apiFormat as RouteFallback["apiFormat"] }, + }); + }; + + return ( +
+ +
+
+

+ {t("providerAdvanced.modelRoutingConfigDesc", { + defaultValue: + "配置模型路由规则,根据请求中的模型名称自动切换到不同的 API 格式和目标端点。例如:将 Claude Opus 路由到 Vertex Gemini,将 Claude Sonnet 路由到 Gemini Flash。", + })} +

+ + {/* 路由规则列表 */} +
+
+ + +
+ + {config.routes.length === 0 && ( +
+ +

+ {t("providerAdvanced.noRoutesYet", { + defaultValue: "暂无路由规则,点击上方按钮添加", + })} +

+
+ )} + + {config.routes.map((route, index) => ( + handleUpdateRoute(index, updated)} + onDelete={() => handleDeleteRoute(index)} + onDuplicate={() => handleDuplicateRoute(index)} + /> + ))} +
+ + {/* 兜底配置 */} +
+ +

+ {t("providerAdvanced.fallbackConfigDesc", { + defaultValue: + "当请求模型没有匹配到任何路由规则时,使用兜底配置。留空则使用供应商默认 API 格式。", + })} +

+
+ + +
+
+
+
+
+ ); +} + +interface RouteItemProps { + route: ModelRoute; + index: number; + disabled: boolean; + onUpdate: (route: ModelRoute) => void; + onDelete: () => void; + onDuplicate: () => void; +} + +function RouteItem({ + route, + index, + disabled, + onUpdate, + onDelete, + onDuplicate, +}: RouteItemProps) { + const { t } = useTranslation(); + + return ( +
+
+ +
+ + +
+
+ + {/* 源模型 */} +
+ + + onUpdate({ ...route, sourceModel: e.target.value }) + } + placeholder={t("providerAdvanced.sourceModelPlaceholder", { + defaultValue: "例如:claude-opus-4-5", + })} + disabled={disabled} + className="font-mono" + /> +

+ {t("providerAdvanced.sourceModelHint", { + defaultValue: "不区分大小写匹配,例如 claude-opus-4-5 可以匹配 Claude Opus 4.5", + })} +

+
+ + {/* 箭头指示 */} +
+ +
+ + {/* 目标配置 */} +
+ + +
+
+ + + onUpdate({ + ...route, + target: { ...route.target, baseUrl: e.target.value }, + }) + } + placeholder="https://api.vertexai.example.com" + disabled={disabled} + className="font-mono text-xs" + /> +
+ +
+ + +
+ +
+ + + onUpdate({ + ...route, + target: { ...route.target, modelName: e.target.value }, + }) + } + placeholder="google/gemini-2.0-flash" + disabled={disabled} + className="font-mono text-xs" + /> +
+
+
+
+ ); +} diff --git a/src/components/providers/forms/ProviderAdvancedConfig.tsx b/src/components/providers/forms/ProviderAdvancedConfig.tsx index 6023bd5f8..97c0cd6df 100644 --- a/src/components/providers/forms/ProviderAdvancedConfig.tsx +++ b/src/components/providers/forms/ProviderAdvancedConfig.tsx @@ -22,7 +22,12 @@ import { SelectValue, } from "@/components/ui/select"; import { cn } from "@/lib/utils"; -import type { ProviderTestConfig, ProviderProxyConfig } from "@/types"; +import type { + ProviderTestConfig, + ProviderProxyConfig, + ModelRoutingConfig, +} from "@/types"; +import { ModelRoutingConfigPanel } from "./ModelRoutingConfigPanel"; export type PricingModelSourceOption = "inherit" | "request" | "response"; @@ -36,9 +41,11 @@ interface ProviderAdvancedConfigProps { testConfig: ProviderTestConfig; proxyConfig: ProviderProxyConfig; pricingConfig: ProviderPricingConfig; + modelRoutingConfig: ModelRoutingConfig; onTestConfigChange: (config: ProviderTestConfig) => void; onProxyConfigChange: (config: ProviderProxyConfig) => void; onPricingConfigChange: (config: ProviderPricingConfig) => void; + onModelRoutingConfigChange: (config: ModelRoutingConfig) => void; } /** 从 ProviderProxyConfig 构建完整 URL */ @@ -90,9 +97,11 @@ export function ProviderAdvancedConfig({ testConfig, proxyConfig, pricingConfig, + modelRoutingConfig, onTestConfigChange, onProxyConfigChange, onPricingConfigChange, + onModelRoutingConfigChange, }: ProviderAdvancedConfigProps) { const { t } = useTranslation(); const [isTestConfigOpen, setIsTestConfigOpen] = useState(testConfig.enabled); @@ -124,6 +133,11 @@ export function ProviderAdvancedConfig({ setIsPricingConfigOpen(pricingConfig.enabled); }, [pricingConfig.enabled]); + // 同步外部 pricingConfig.enabled 变化到展开状态 + useEffect(() => { + setIsPricingConfigOpen(pricingConfig.enabled); + }, [pricingConfig.enabled]); + // 仅在外部 proxyConfig 变化且非用户输入时同步(如:重置表单、加载数据) useEffect(() => { if (!isUserTyping) { @@ -166,6 +180,11 @@ export function ProviderAdvancedConfig({ return (
+ +