From 7b7b258c38729b0924c6300aba8e77912d48e31b Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 11 Mar 2026 10:47:33 +0800 Subject: [PATCH 01/13] Fixed: #2022 test(translator): add tests for handling Claude system messages as string and array --- .../codex/claude/codex_claude_request.go | 33 ++++--- .../codex/claude/codex_claude_request_test.go | 89 +++++++++++++++++++ 2 files changed, 110 insertions(+), 12 deletions(-) create mode 100644 internal/translator/codex/claude/codex_claude_request_test.go diff --git a/internal/translator/codex/claude/codex_claude_request.go b/internal/translator/codex/claude/codex_claude_request.go index 6373e69336..4bc116b9fb 100644 --- a/internal/translator/codex/claude/codex_claude_request.go +++ b/internal/translator/codex/claude/codex_claude_request.go @@ -43,23 +43,32 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) // Process system messages and convert them to input content format. systemsResult := rootResult.Get("system") - if systemsResult.IsArray() { - systemResults := systemsResult.Array() + if systemsResult.Exists() { message := `{"type":"message","role":"developer","content":[]}` contentIndex := 0 - for i := 0; i < len(systemResults); i++ { - systemResult := systemResults[i] - systemTypeResult := systemResult.Get("type") - if systemTypeResult.String() == "text" { - text := systemResult.Get("text").String() - if strings.HasPrefix(text, "x-anthropic-billing-header: ") { - continue + + appendSystemText := func(text string) { + if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") { + return + } + + message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text") + message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text) + contentIndex++ + } + + if systemsResult.Type == gjson.String { + appendSystemText(systemsResult.String()) + } else if systemsResult.IsArray() { + systemResults := systemsResult.Array() + for i := 0; i < len(systemResults); i++ { + systemResult := systemResults[i] + if systemResult.Get("type").String() == "text" { + appendSystemText(systemResult.Get("text").String()) } - message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text") - message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text) - contentIndex++ } } + if contentIndex > 0 { template, _ = sjson.SetRaw(template, "input.-1", message) } diff --git a/internal/translator/codex/claude/codex_claude_request_test.go b/internal/translator/codex/claude/codex_claude_request_test.go new file mode 100644 index 0000000000..bdd41639c1 --- /dev/null +++ b/internal/translator/codex/claude/codex_claude_request_test.go @@ -0,0 +1,89 @@ +package claude + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) { + tests := []struct { + name string + inputJSON string + wantHasDeveloper bool + wantTexts []string + }{ + { + name: "No system field", + inputJSON: `{ + "model": "claude-3-opus", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasDeveloper: false, + }, + { + name: "Empty string system field", + inputJSON: `{ + "model": "claude-3-opus", + "system": "", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasDeveloper: false, + }, + { + name: "String system field", + inputJSON: `{ + "model": "claude-3-opus", + "system": "Be helpful", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasDeveloper: true, + wantTexts: []string{"Be helpful"}, + }, + { + name: "Array system field with filtered billing header", + inputJSON: `{ + "model": "claude-3-opus", + "system": [ + {"type": "text", "text": "x-anthropic-billing-header: tenant-123"}, + {"type": "text", "text": "Block 1"}, + {"type": "text", "text": "Block 2"} + ], + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasDeveloper: true, + wantTexts: []string{"Block 1", "Block 2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false) + resultJSON := gjson.ParseBytes(result) + inputs := resultJSON.Get("input").Array() + + hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer" + if hasDeveloper != tt.wantHasDeveloper { + t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw) + } + + if !tt.wantHasDeveloper { + return + } + + content := inputs[0].Get("content").Array() + if len(content) != len(tt.wantTexts) { + t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw) + } + + for i, wantText := range tt.wantTexts { + if gotType := content[i].Get("type").String(); gotType != "input_text" { + t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text") + } + if gotText := content[i].Get("text").String(); gotText != wantText { + t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText) + } + } + }) + } +} From ddaa9d2436e862146fe099d7d9dc06238b3c6ec4 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 11 Mar 2026 11:08:02 +0800 Subject: [PATCH 02/13] Fixed: #2034 feat(proxy): centralize proxy handling with `proxyutil` package and enhance test coverage - Added `proxyutil` package to simplify proxy handling across the codebase. - Refactored various components (`executor`, `cliproxy`, `auth`, etc.) to use `proxyutil` for consistent and reusable proxy logic. - Introduced support for "direct" proxy mode to explicitly bypass all proxies. - Updated tests to validate proxy behavior (e.g., `direct`, HTTP/HTTPS, and SOCKS5). - Enhanced YAML configuration documentation for proxy options. --- config.example.yaml | 6 + internal/api/handlers/management/api_tools.go | 46 +---- .../api/handlers/management/api_tools_test.go | 177 +++--------------- .../handlers/management/test_store_test.go | 49 +++++ internal/auth/claude/utls_transport.go | 19 +- internal/auth/gemini/gemini_auth.go | 40 +--- .../executor/codex_websockets_executor.go | 28 ++- .../codex_websockets_executor_test.go | 16 ++ internal/runtime/executor/proxy_helpers.go | 45 +---- .../runtime/executor/proxy_helpers_test.go | 30 +++ internal/util/proxy.go | 41 +--- sdk/cliproxy/rtprovider.go | 36 +--- sdk/cliproxy/rtprovider_test.go | 22 +++ sdk/proxyutil/proxy.go | 139 ++++++++++++++ sdk/proxyutil/proxy_test.go | 89 +++++++++ 15 files changed, 439 insertions(+), 344 deletions(-) create mode 100644 internal/api/handlers/management/test_store_test.go create mode 100644 internal/runtime/executor/proxy_helpers_test.go create mode 100644 sdk/cliproxy/rtprovider_test.go create mode 100644 sdk/proxyutil/proxy.go create mode 100644 sdk/proxyutil/proxy_test.go diff --git a/config.example.yaml b/config.example.yaml index 348aabd846..a75b69f0f1 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -63,6 +63,7 @@ error-logs-max-files: 10 usage-statistics-enabled: false # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/ +# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly. proxy-url: "" # When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name). @@ -110,6 +111,7 @@ nonstream-keepalive-interval: 0 # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" +# # proxy-url: "direct" # optional: explicit direct connect for this credential # models: # - name: "gemini-2.5-flash" # upstream model name # alias: "gemini-flash" # client alias mapped to the upstream model @@ -128,6 +130,7 @@ nonstream-keepalive-interval: 0 # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override +# # proxy-url: "direct" # optional: explicit direct connect for this credential # models: # - name: "gpt-5-codex" # upstream model name # alias: "codex-latest" # client alias mapped to the upstream model @@ -146,6 +149,7 @@ nonstream-keepalive-interval: 0 # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override +# # proxy-url: "direct" # optional: explicit direct connect for this credential # models: # - name: "claude-3-5-sonnet-20241022" # upstream model name # alias: "claude-sonnet-latest" # client alias mapped to the upstream model @@ -183,6 +187,7 @@ nonstream-keepalive-interval: 0 # api-key-entries: # - api-key: "sk-or-v1-...b780" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override +# # proxy-url: "direct" # optional: explicit direct connect for this credential # - api-key: "sk-or-v1-...b781" # without proxy-url # models: # The models supported by the provider. # - name: "moonshotai/kimi-k2:free" # The actual model name. @@ -205,6 +210,7 @@ nonstream-keepalive-interval: 0 # prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential # base-url: "https://example.com/api" # e.g. https://zenmux.ai/api # proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override +# # proxy-url: "direct" # optional: explicit direct connect for this credential # headers: # X-Custom-Header: "custom-value" # models: # optional: map aliases to upstream model names diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go index c7846a7599..de546ea820 100644 --- a/internal/api/handlers/management/api_tools.go +++ b/internal/api/handlers/management/api_tools.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/url" "strings" @@ -14,8 +13,8 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" "golang.org/x/oauth2" "golang.org/x/oauth2/google" ) @@ -660,45 +659,10 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { } func buildProxyTransport(proxyStr string) *http.Transport { - proxyStr = strings.TrimSpace(proxyStr) - if proxyStr == "" { + transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr) + if errBuild != nil { + log.WithError(errBuild).Debug("build proxy transport failed") return nil } - - proxyURL, errParse := url.Parse(proxyStr) - if errParse != nil { - log.WithError(errParse).Debug("parse proxy URL failed") - return nil - } - if proxyURL.Scheme == "" || proxyURL.Host == "" { - log.Debug("proxy URL missing scheme/host") - return nil - } - - if proxyURL.Scheme == "socks5" { - var proxyAuth *proxy.Auth - if proxyURL.User != nil { - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - proxyAuth = &proxy.Auth{User: username, Password: password} - } - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed") - return nil - } - return &http.Transport{ - Proxy: nil, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } - - if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - return &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } - - log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme) - return nil + return transport } diff --git a/internal/api/handlers/management/api_tools_test.go b/internal/api/handlers/management/api_tools_test.go index fecbee9cb8..5b0c63693a 100644 --- a/internal/api/handlers/management/api_tools_test.go +++ b/internal/api/handlers/management/api_tools_test.go @@ -1,173 +1,58 @@ package management import ( - "context" - "encoding/json" - "io" "net/http" - "net/http/httptest" - "net/url" - "strings" - "sync" "testing" - "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) -type memoryAuthStore struct { - mu sync.Mutex - items map[string]*coreauth.Auth -} +func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) { + t.Parallel() -func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) { - _ = ctx - s.mu.Lock() - defer s.mu.Unlock() - out := make([]*coreauth.Auth, 0, len(s.items)) - for _, a := range s.items { - out = append(out, a.Clone()) + h := &Handler{ + cfg: &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}, + }, } - return out, nil -} -func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) { - _ = ctx - if auth == nil { - return "", nil + transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"}) + httpTransport, ok := transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", transport) } - s.mu.Lock() - if s.items == nil { - s.items = make(map[string]*coreauth.Auth) + if httpTransport.Proxy != nil { + t.Fatal("expected direct transport to disable proxy function") } - s.items[auth.ID] = auth.Clone() - s.mu.Unlock() - return auth.ID, nil -} - -func (s *memoryAuthStore) Delete(ctx context.Context, id string) error { - _ = ctx - s.mu.Lock() - delete(s.items, id) - s.mu.Unlock() - return nil } -func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) { - var callCount int - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - if r.Method != http.MethodPost { - t.Fatalf("expected POST, got %s", r.Method) - } - if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") { - t.Fatalf("unexpected content-type: %s", ct) - } - bodyBytes, _ := io.ReadAll(r.Body) - _ = r.Body.Close() - values, err := url.ParseQuery(string(bodyBytes)) - if err != nil { - t.Fatalf("parse form: %v", err) - } - if values.Get("grant_type") != "refresh_token" { - t.Fatalf("unexpected grant_type: %s", values.Get("grant_type")) - } - if values.Get("refresh_token") != "rt" { - t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token")) - } - if values.Get("client_id") != antigravityOAuthClientID { - t.Fatalf("unexpected client_id: %s", values.Get("client_id")) - } - if values.Get("client_secret") != antigravityOAuthClientSecret { - t.Fatalf("unexpected client_secret") - } +func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) { + t.Parallel() - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "access_token": "new-token", - "refresh_token": "rt2", - "expires_in": int64(3600), - "token_type": "Bearer", - }) - })) - t.Cleanup(srv.Close) - - originalURL := antigravityOAuthTokenURL - antigravityOAuthTokenURL = srv.URL - t.Cleanup(func() { antigravityOAuthTokenURL = originalURL }) - - store := &memoryAuthStore{} - manager := coreauth.NewManager(store, nil, nil) - - auth := &coreauth.Auth{ - ID: "antigravity-test.json", - FileName: "antigravity-test.json", - Provider: "antigravity", - Metadata: map[string]any{ - "type": "antigravity", - "access_token": "old-token", - "refresh_token": "rt", - "expires_in": int64(3600), - "timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(), - "expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + h := &Handler{ + cfg: &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}, }, } - if _, err := manager.Register(context.Background(), auth); err != nil { - t.Fatalf("register auth: %v", err) - } - h := &Handler{authManager: manager} - token, err := h.resolveTokenForAuth(context.Background(), auth) - if err != nil { - t.Fatalf("resolveTokenForAuth: %v", err) - } - if token != "new-token" { - t.Fatalf("expected refreshed token, got %q", token) - } - if callCount != 1 { - t.Fatalf("expected 1 refresh call, got %d", callCount) + transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"}) + httpTransport, ok := transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", transport) } - updated, ok := manager.GetByID(auth.ID) - if !ok || updated == nil { - t.Fatalf("expected auth in manager after update") - } - if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" { - t.Fatalf("expected manager metadata updated, got %q", got) + req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errRequest != nil { + t.Fatalf("http.NewRequest returned error: %v", errRequest) } -} - -func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) { - var callCount int - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - w.WriteHeader(http.StatusInternalServerError) - })) - t.Cleanup(srv.Close) - originalURL := antigravityOAuthTokenURL - antigravityOAuthTokenURL = srv.URL - t.Cleanup(func() { antigravityOAuthTokenURL = originalURL }) - - auth := &coreauth.Auth{ - ID: "antigravity-valid.json", - FileName: "antigravity-valid.json", - Provider: "antigravity", - Metadata: map[string]any{ - "type": "antigravity", - "access_token": "ok-token", - "expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339), - }, - } - h := &Handler{} - token, err := h.resolveTokenForAuth(context.Background(), auth) - if err != nil { - t.Fatalf("resolveTokenForAuth: %v", err) - } - if token != "ok-token" { - t.Fatalf("expected existing token, got %q", token) + proxyURL, errProxy := httpTransport.Proxy(req) + if errProxy != nil { + t.Fatalf("httpTransport.Proxy returned error: %v", errProxy) } - if callCount != 0 { - t.Fatalf("expected no refresh calls, got %d", callCount) + if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" { + t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL) } } diff --git a/internal/api/handlers/management/test_store_test.go b/internal/api/handlers/management/test_store_test.go new file mode 100644 index 0000000000..cf7dbaf7d0 --- /dev/null +++ b/internal/api/handlers/management/test_store_test.go @@ -0,0 +1,49 @@ +package management + +import ( + "context" + "sync" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +type memoryAuthStore struct { + mu sync.Mutex + items map[string]*coreauth.Auth +} + +func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) { + s.mu.Lock() + defer s.mu.Unlock() + + out := make([]*coreauth.Auth, 0, len(s.items)) + for _, item := range s.items { + out = append(out, item) + } + return out, nil +} + +func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) { + if auth == nil { + return "", nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.items == nil { + s.items = make(map[string]*coreauth.Auth) + } + s.items[auth.ID] = auth + return auth.ID, nil +} + +func (s *memoryAuthStore) Delete(_ context.Context, id string) error { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.items, id) + return nil +} + +func (s *memoryAuthStore) SetBaseDir(string) {} diff --git a/internal/auth/claude/utls_transport.go b/internal/auth/claude/utls_transport.go index 27ec87e136..88b69c9bd9 100644 --- a/internal/auth/claude/utls_transport.go +++ b/internal/auth/claude/utls_transport.go @@ -4,12 +4,12 @@ package claude import ( "net/http" - "net/url" "strings" "sync" tls "github.com/refraction-networking/utls" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" "golang.org/x/net/http2" "golang.org/x/net/proxy" @@ -31,17 +31,12 @@ type utlsRoundTripper struct { // newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper { var dialer proxy.Dialer = proxy.Direct - if cfg != nil && cfg.ProxyURL != "" { - proxyURL, err := url.Parse(cfg.ProxyURL) - if err != nil { - log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err) - } else { - pDialer, err := proxy.FromURL(proxyURL, proxy.Direct) - if err != nil { - log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err) - } else { - dialer = pDialer - } + if cfg != nil { + proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL) + if errBuild != nil { + log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild) + } else if mode != proxyutil.ModeInherit && proxyDialer != nil { + dialer = proxyDialer } } diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index 6406a0e156..c459c5ca33 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -10,9 +10,7 @@ import ( "errors" "fmt" "io" - "net" "net/http" - "net/url" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" @@ -20,9 +18,9 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "golang.org/x/net/proxy" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken } callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort) - // Configure proxy settings for the HTTP client if a proxy URL is provided. - proxyURL, err := url.Parse(cfg.ProxyURL) - if err == nil { - var transport *http.Transport - if proxyURL.Scheme == "socks5" { - // Handle SOCKS5 proxy. - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - auth := &proxy.Auth{User: username, Password: password} - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) - } - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Handle HTTP/HTTPS proxy. - transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } - - if transport != nil { - proxyClient := &http.Client{Transport: transport} - ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) - } + transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL) + if errBuild != nil { + log.Errorf("%v", errBuild) + } else if transport != nil { + proxyClient := &http.Client{Transport: transport} + ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) } + var err error + // Configure the OAuth2 client. conf := &oauth2.Config{ ClientID: ClientID, diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 1f3400500c..42a9e797b0 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -23,6 +23,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -705,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) * return dialer } - parsedURL, errParse := url.Parse(proxyURL) + setting, errParse := proxyutil.Parse(proxyURL) if errParse != nil { - log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse) + log.Errorf("codex websockets executor: %v", errParse) return dialer } - switch parsedURL.Scheme { + switch setting.Mode { + case proxyutil.ModeDirect: + dialer.Proxy = nil + return dialer + case proxyutil.ModeProxy: + default: + return dialer + } + + switch setting.URL.Scheme { case "socks5": var proxyAuth *proxy.Auth - if parsedURL.User != nil { - username := parsedURL.User.Username() - password, _ := parsedURL.User.Password() + if setting.URL.User != nil { + username := setting.URL.User.Username() + password, _ := setting.URL.User.Password() proxyAuth = &proxy.Auth{User: username, Password: password} } - socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct) + socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct) if errSOCKS5 != nil { log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5) return dialer @@ -729,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) * return socksDialer.Dial(network, addr) } case "http", "https": - dialer.Proxy = http.ProxyURL(parsedURL) + dialer.Proxy = http.ProxyURL(setting.URL) default: - log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme) + log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme) } return dialer diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index 1fd685138c..20d44581d8 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -5,6 +5,9 @@ import ( "net/http" "testing" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "github.com/tidwall/gjson" ) @@ -34,3 +37,16 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) } } + +func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) { + t.Parallel() + + dialer := newProxyAwareWebsocketDialer( + &config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}}, + &cliproxyauth.Auth{ProxyURL: "direct"}, + ) + + if dialer.Proxy != nil { + t.Fatal("expected websocket proxy function to be nil for direct mode") + } +} diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index ab0f626acc..5511497b9e 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -2,16 +2,14 @@ package executor import ( "context" - "net" "net/http" - "net/url" "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" ) // newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: @@ -72,45 +70,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip // Returns: // - *http.Transport: A configured transport, or nil if the proxy URL is invalid func buildProxyTransport(proxyURL string) *http.Transport { - if proxyURL == "" { + transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL) + if errBuild != nil { + log.Errorf("%v", errBuild) return nil } - - parsedURL, errParse := url.Parse(proxyURL) - if errParse != nil { - log.Errorf("parse proxy URL failed: %v", errParse) - return nil - } - - var transport *http.Transport - - // Handle different proxy schemes - if parsedURL.Scheme == "socks5" { - // Configure SOCKS5 proxy with optional authentication - var proxyAuth *proxy.Auth - if parsedURL.User != nil { - username := parsedURL.User.Username() - password, _ := parsedURL.User.Password() - proxyAuth = &proxy.Auth{User: username, Password: password} - } - dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return nil - } - // Set up a custom transport using the SOCKS5 dialer - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy - transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)} - } else { - log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) - return nil - } - return transport } diff --git a/internal/runtime/executor/proxy_helpers_test.go b/internal/runtime/executor/proxy_helpers_test.go new file mode 100644 index 0000000000..4ae5c93766 --- /dev/null +++ b/internal/runtime/executor/proxy_helpers_test.go @@ -0,0 +1,30 @@ +package executor + +import ( + "context" + "net/http" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) { + t.Parallel() + + client := newProxyAwareHTTPClient( + context.Background(), + &config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}}, + &cliproxyauth.Auth{ProxyURL: "direct"}, + 0, + ) + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", client.Transport) + } + if transport.Proxy != nil { + t.Fatal("expected direct transport to disable proxy function") + } +} diff --git a/internal/util/proxy.go b/internal/util/proxy.go index aea52ba8ce..9b57ca1733 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -4,50 +4,25 @@ package util import ( - "context" - "net" "net/http" - "net/url" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" ) // SetProxy configures the provided HTTP client with proxy settings from the configuration. // It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport // to route requests through the configured proxy server. func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client { - var transport *http.Transport - // Attempt to parse the proxy URL from the configuration. - proxyURL, errParse := url.Parse(cfg.ProxyURL) - if errParse == nil { - // Handle different proxy schemes. - if proxyURL.Scheme == "socks5" { - // Configure SOCKS5 proxy with optional authentication. - var proxyAuth *proxy.Auth - if proxyURL.User != nil { - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - proxyAuth = &proxy.Auth{User: username, Password: password} - } - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return httpClient - } - // Set up a custom transport using the SOCKS5 dialer. - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy. - transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } + if cfg == nil || httpClient == nil { + return httpClient + } + + transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL) + if errBuild != nil { + log.Errorf("%v", errBuild) } - // If a new transport was created, apply it to the HTTP client. if transport != nil { httpClient.Transport = transport } diff --git a/sdk/cliproxy/rtprovider.go b/sdk/cliproxy/rtprovider.go index dad4fc2387..5c4f579a85 100644 --- a/sdk/cliproxy/rtprovider.go +++ b/sdk/cliproxy/rtprovider.go @@ -1,16 +1,13 @@ package cliproxy import ( - "context" - "net" "net/http" - "net/url" "strings" "sync" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" ) // defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on @@ -39,35 +36,12 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http. if rt != nil { return rt } - // Parse the proxy URL to determine the scheme. - proxyURL, errParse := url.Parse(proxyStr) - if errParse != nil { - log.Errorf("parse proxy URL failed: %v", errParse) + transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr) + if errBuild != nil { + log.Errorf("%v", errBuild) return nil } - var transport *http.Transport - // Handle different proxy schemes. - if proxyURL.Scheme == "socks5" { - // Configure SOCKS5 proxy with optional authentication. - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - proxyAuth := &proxy.Auth{User: username, Password: password} - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return nil - } - // Set up a custom transport using the SOCKS5 dialer. - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy. - transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } else { - log.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) + if transport == nil { return nil } p.mu.Lock() diff --git a/sdk/cliproxy/rtprovider_test.go b/sdk/cliproxy/rtprovider_test.go new file mode 100644 index 0000000000..f907081e29 --- /dev/null +++ b/sdk/cliproxy/rtprovider_test.go @@ -0,0 +1,22 @@ +package cliproxy + +import ( + "net/http" + "testing" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestRoundTripperForDirectBypassesProxy(t *testing.T) { + t.Parallel() + + provider := newDefaultRoundTripperProvider() + rt := provider.RoundTripperFor(&coreauth.Auth{ProxyURL: "direct"}) + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", rt) + } + if transport.Proxy != nil { + t.Fatal("expected direct transport to disable proxy function") + } +} diff --git a/sdk/proxyutil/proxy.go b/sdk/proxyutil/proxy.go new file mode 100644 index 0000000000..591ec9d9c0 --- /dev/null +++ b/sdk/proxyutil/proxy.go @@ -0,0 +1,139 @@ +package proxyutil + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/proxy" +) + +// Mode describes how a proxy setting should be interpreted. +type Mode int + +const ( + // ModeInherit means no explicit proxy behavior was configured. + ModeInherit Mode = iota + // ModeDirect means outbound requests must bypass proxies explicitly. + ModeDirect + // ModeProxy means a concrete proxy URL was configured. + ModeProxy + // ModeInvalid means the proxy setting is present but malformed or unsupported. + ModeInvalid +) + +// Setting is the normalized interpretation of a proxy configuration value. +type Setting struct { + Raw string + Mode Mode + URL *url.URL +} + +// Parse normalizes a proxy configuration value into inherit, direct, or proxy modes. +func Parse(raw string) (Setting, error) { + trimmed := strings.TrimSpace(raw) + setting := Setting{Raw: trimmed} + + if trimmed == "" { + setting.Mode = ModeInherit + return setting, nil + } + + if strings.EqualFold(trimmed, "direct") || strings.EqualFold(trimmed, "none") { + setting.Mode = ModeDirect + return setting, nil + } + + parsedURL, errParse := url.Parse(trimmed) + if errParse != nil { + setting.Mode = ModeInvalid + return setting, fmt.Errorf("parse proxy URL failed: %w", errParse) + } + if parsedURL.Scheme == "" || parsedURL.Host == "" { + setting.Mode = ModeInvalid + return setting, fmt.Errorf("proxy URL missing scheme/host") + } + + switch parsedURL.Scheme { + case "socks5", "http", "https": + setting.Mode = ModeProxy + setting.URL = parsedURL + return setting, nil + default: + setting.Mode = ModeInvalid + return setting, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) + } +} + +// NewDirectTransport returns a transport that bypasses environment proxies. +func NewDirectTransport() *http.Transport { + if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil { + clone := transport.Clone() + clone.Proxy = nil + return clone + } + return &http.Transport{Proxy: nil} +} + +// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting. +func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) { + setting, errParse := Parse(raw) + if errParse != nil { + return nil, setting.Mode, errParse + } + + switch setting.Mode { + case ModeInherit: + return nil, setting.Mode, nil + case ModeDirect: + return NewDirectTransport(), setting.Mode, nil + case ModeProxy: + if setting.URL.Scheme == "socks5" { + var proxyAuth *proxy.Auth + if setting.URL.User != nil { + username := setting.URL.User.Username() + password, _ := setting.URL.User.Password() + proxyAuth = &proxy.Auth{User: username, Password: password} + } + dialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct) + if errSOCKS5 != nil { + return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) + } + return &http.Transport{ + Proxy: nil, + DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, + }, setting.Mode, nil + } + return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil + default: + return nil, setting.Mode, nil + } +} + +// BuildDialer constructs a proxy dialer for settings that operate at the connection layer. +func BuildDialer(raw string) (proxy.Dialer, Mode, error) { + setting, errParse := Parse(raw) + if errParse != nil { + return nil, setting.Mode, errParse + } + + switch setting.Mode { + case ModeInherit: + return nil, setting.Mode, nil + case ModeDirect: + return proxy.Direct, setting.Mode, nil + case ModeProxy: + dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct) + if errDialer != nil { + return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer) + } + return dialer, setting.Mode, nil + default: + return nil, setting.Mode, nil + } +} diff --git a/sdk/proxyutil/proxy_test.go b/sdk/proxyutil/proxy_test.go new file mode 100644 index 0000000000..bea413dc24 --- /dev/null +++ b/sdk/proxyutil/proxy_test.go @@ -0,0 +1,89 @@ +package proxyutil + +import ( + "net/http" + "testing" +) + +func TestParse(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want Mode + wantErr bool + }{ + {name: "inherit", input: "", want: ModeInherit}, + {name: "direct", input: "direct", want: ModeDirect}, + {name: "none", input: "none", want: ModeDirect}, + {name: "http", input: "http://proxy.example.com:8080", want: ModeProxy}, + {name: "https", input: "https://proxy.example.com:8443", want: ModeProxy}, + {name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy}, + {name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + setting, errParse := Parse(tt.input) + if tt.wantErr && errParse == nil { + t.Fatal("expected error, got nil") + } + if !tt.wantErr && errParse != nil { + t.Fatalf("unexpected error: %v", errParse) + } + if setting.Mode != tt.want { + t.Fatalf("mode = %d, want %d", setting.Mode, tt.want) + } + }) + } +} + +func TestBuildHTTPTransportDirectBypassesProxy(t *testing.T) { + t.Parallel() + + transport, mode, errBuild := BuildHTTPTransport("direct") + if errBuild != nil { + t.Fatalf("BuildHTTPTransport returned error: %v", errBuild) + } + if mode != ModeDirect { + t.Fatalf("mode = %d, want %d", mode, ModeDirect) + } + if transport == nil { + t.Fatal("expected transport, got nil") + } + if transport.Proxy != nil { + t.Fatal("expected direct transport to disable proxy function") + } +} + +func TestBuildHTTPTransportHTTPProxy(t *testing.T) { + t.Parallel() + + transport, mode, errBuild := BuildHTTPTransport("http://proxy.example.com:8080") + if errBuild != nil { + t.Fatalf("BuildHTTPTransport returned error: %v", errBuild) + } + if mode != ModeProxy { + t.Fatalf("mode = %d, want %d", mode, ModeProxy) + } + if transport == nil { + t.Fatal("expected transport, got nil") + } + + req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errRequest != nil { + t.Fatalf("http.NewRequest returned error: %v", errRequest) + } + + proxyURL, errProxy := transport.Proxy(req) + if errProxy != nil { + t.Fatalf("transport.Proxy returned error: %v", errProxy) + } + if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" { + t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL) + } +} From 70988d387b232b086a79cfce1e16f599238c6ce3 Mon Sep 17 00:00:00 2001 From: lang-911 Date: Wed, 11 Mar 2026 00:34:57 -0700 Subject: [PATCH 03/13] Add Codex websocket header defaults --- config.example.yaml | 8 + .../codex_websocket_header_defaults_test.go | 32 ++++ internal/config/config.go | 25 +++ internal/runtime/executor/codex_executor.go | 11 +- .../executor/codex_websockets_executor.go | 67 +++++++- .../codex_websockets_executor_test.go | 155 +++++++++++++++++- 6 files changed, 287 insertions(+), 11 deletions(-) create mode 100644 internal/config/codex_websocket_header_defaults_test.go diff --git a/config.example.yaml b/config.example.yaml index 40bb87210a..16be5c3698 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -173,6 +173,14 @@ nonstream-keepalive-interval: 0 # runtime-version: "v24.3.0" # timeout: "600" +# Default headers for Codex OAuth model requests. +# These are used only for file-backed/OAuth Codex requests when the client +# does not send the header. `user-agent` applies to HTTP and websocket requests; +# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries. +# codex-header-defaults: +# user-agent: "my-codex-client/1.0" +# beta-features: "feature-a,feature-b" + # OpenAI compatibility providers # openai-compatibility: # - name: "openrouter" # The name of the provider; it will be used in the user agent and other places. diff --git a/internal/config/codex_websocket_header_defaults_test.go b/internal/config/codex_websocket_header_defaults_test.go new file mode 100644 index 0000000000..49947c1cf6 --- /dev/null +++ b/internal/config/codex_websocket_header_defaults_test.go @@ -0,0 +1,32 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.yaml") + configYAML := []byte(` +codex-header-defaults: + user-agent: " my-codex-client/1.0 " + beta-features: " feature-a,feature-b " +`) + if err := os.WriteFile(configPath, configYAML, 0o600); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + cfg, err := LoadConfigOptional(configPath, false) + if err != nil { + t.Fatalf("LoadConfigOptional() error = %v", err) + } + + if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" { + t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0") + } + if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" { + t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 5a6595f778..7bd137e0db 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -90,6 +90,10 @@ type Config struct { // Codex defines a list of Codex API key configurations as specified in the YAML configuration file. CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"` + // CodexHeaderDefaults configures fallback headers for Codex OAuth model requests. + // These are used only when the client does not send its own headers. + CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"` + // ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file. ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"` @@ -133,6 +137,14 @@ type ClaudeHeaderDefaults struct { Timeout string `yaml:"timeout" json:"timeout"` } +// CodexHeaderDefaults configures fallback header values injected into Codex +// model requests for OAuth/file-backed auth when the client omits them. +// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets. +type CodexHeaderDefaults struct { + UserAgent string `yaml:"user-agent" json:"user-agent"` + BetaFeatures string `yaml:"beta-features" json:"beta-features"` +} + // TLSConfig holds HTTPS server settings. type TLSConfig struct { // Enable toggles HTTPS server mode. @@ -615,6 +627,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Codex keys: drop entries without base-url cfg.SanitizeCodexKeys() + // Sanitize Codex header defaults. + cfg.SanitizeCodexHeaderDefaults() + // Sanitize Claude key headers cfg.SanitizeClaudeKeys() @@ -704,6 +719,16 @@ func payloadRawString(value any) ([]byte, bool) { } } +// SanitizeCodexHeaderDefaults trims surrounding whitespace from the +// configured Codex header fallback values. +func (cfg *Config) SanitizeCodexHeaderDefaults() { + if cfg == nil { + return + } + cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent) + cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures) +} + // SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases. // It trims whitespace, normalizes channel keys to lower-case, drops empty entries, // allows multiple aliases per upstream name, and ensures aliases are unique within each channel. diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 30092ec737..4fb2291900 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -122,7 +122,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re if err != nil { return resp, err } - applyCodexHeaders(httpReq, auth, apiKey, true) + applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg) var authID, authLabel, authType, authValue string if auth != nil { authID = auth.ID @@ -226,7 +226,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A if err != nil { return resp, err } - applyCodexHeaders(httpReq, auth, apiKey, false) + applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg) var authID, authLabel, authType, authValue string if auth != nil { authID = auth.ID @@ -321,7 +321,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au if err != nil { return nil, err } - applyCodexHeaders(httpReq, auth, apiKey, true) + applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg) var authID, authLabel, authType, authValue string if auth != nil { authID = auth.ID @@ -636,7 +636,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form return httpReq, nil } -func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool) { +func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) { r.Header.Set("Content-Type", "application/json") r.Header.Set("Authorization", "Bearer "+token) @@ -647,7 +647,8 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion) misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString()) - misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent) + cfgUserAgent, _ := codexHeaderDefaults(cfg, auth) + ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent) if stream { r.Header.Set("Accept", "text/event-stream") diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 1f3400500c..2a4f4a3ff2 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -190,7 +190,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body) - wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey) + wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg) var authID, authLabel, authType, authValue string if auth != nil { @@ -385,7 +385,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr } body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body) - wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey) + wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg) var authID, authLabel, authType, authValue string authID = auth.ID @@ -787,7 +787,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto return rawJSON, headers } -func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string) http.Header { +func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header { if headers == nil { headers = http.Header{} } @@ -800,7 +800,8 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth * ginHeaders = ginCtx.Request.Header } - misc.EnsureHeader(headers, ginHeaders, "x-codex-beta-features", "") + cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth) + ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "") misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "") misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "") misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "") @@ -815,7 +816,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth * } headers.Set("OpenAI-Beta", betaHeader) misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString()) - misc.EnsureHeader(headers, ginHeaders, "User-Agent", codexUserAgent) + ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent) isAPIKey := false if auth != nil && auth.Attributes != nil { @@ -843,6 +844,62 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth * return headers } +func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) { + if cfg == nil || auth == nil { + return "", "" + } + if auth.Attributes != nil { + if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" { + return "", "" + } + } + return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures) +} + +func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) { + if target == nil { + return + } + if strings.TrimSpace(target.Get(key)) != "" { + return + } + if source != nil { + if val := strings.TrimSpace(source.Get(key)); val != "" { + target.Set(key, val) + return + } + } + if val := strings.TrimSpace(configValue); val != "" { + target.Set(key, val) + return + } + if val := strings.TrimSpace(fallbackValue); val != "" { + target.Set(key, val) + } +} + +func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) { + if target == nil { + return + } + if strings.TrimSpace(target.Get(key)) != "" { + return + } + if val := strings.TrimSpace(configValue); val != "" { + target.Set(key, val) + return + } + if source != nil { + if val := strings.TrimSpace(source.Get(key)); val != "" { + target.Set(key, val) + return + } + } + if val := strings.TrimSpace(fallbackValue); val != "" { + target.Set(key, val) + } +} + type statusErrWithHeaders struct { statusErr headers http.Header diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index 1fd685138c..e1335386ed 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -3,8 +3,12 @@ package executor import ( "context" "net/http" + "net/http/httptest" "testing" + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/tidwall/gjson" ) @@ -28,9 +32,158 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) } func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) { - headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "") + headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil) if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue { t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) } + if got := headers.Get("User-Agent"); got != codexUserAgent { + t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent) + } + if got := headers.Get("x-codex-beta-features"); got != "" { + t.Fatalf("x-codex-beta-features = %q, want empty", got) + } +} + +func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) { + cfg := &config.Config{ + CodexHeaderDefaults: config.CodexHeaderDefaults{ + UserAgent: "my-codex-client/1.0", + BetaFeatures: "feature-a,feature-b", + }, + } + auth := &cliproxyauth.Auth{ + Provider: "codex", + Metadata: map[string]any{"email": "user@example.com"}, + } + + headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg) + + if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" { + t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0") + } + if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" { + t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b") + } + if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue { + t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) + } +} + +func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *testing.T) { + cfg := &config.Config{ + CodexHeaderDefaults: config.CodexHeaderDefaults{ + UserAgent: "config-ua", + BetaFeatures: "config-beta", + }, + } + auth := &cliproxyauth.Auth{ + Provider: "codex", + Metadata: map[string]any{"email": "user@example.com"}, + } + ctx := contextWithGinHeaders(map[string]string{ + "User-Agent": "client-ua", + "X-Codex-Beta-Features": "client-beta", + }) + headers := http.Header{} + headers.Set("User-Agent", "existing-ua") + headers.Set("X-Codex-Beta-Features", "existing-beta") + + got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg) + + if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" { + t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua") + } + if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" { + t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta") + } +} + +func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testing.T) { + cfg := &config.Config{ + CodexHeaderDefaults: config.CodexHeaderDefaults{ + UserAgent: "config-ua", + BetaFeatures: "config-beta", + }, + } + auth := &cliproxyauth.Auth{ + Provider: "codex", + Metadata: map[string]any{"email": "user@example.com"}, + } + ctx := contextWithGinHeaders(map[string]string{ + "User-Agent": "client-ua", + "X-Codex-Beta-Features": "client-beta", + }) + + headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg) + + if got := headers.Get("User-Agent"); got != "config-ua" { + t.Fatalf("User-Agent = %s, want %s", got, "config-ua") + } + if got := headers.Get("x-codex-beta-features"); got != "client-beta" { + t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta") + } +} + +func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) { + cfg := &config.Config{ + CodexHeaderDefaults: config.CodexHeaderDefaults{ + UserAgent: "config-ua", + BetaFeatures: "config-beta", + }, + } + auth := &cliproxyauth.Auth{ + Provider: "codex", + Attributes: map[string]string{"api_key": "sk-test"}, + } + + headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg) + + if got := headers.Get("User-Agent"); got != codexUserAgent { + t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent) + } + if got := headers.Get("x-codex-beta-features"); got != "" { + t.Fatalf("x-codex-beta-features = %q, want empty", got) + } +} + +func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil) + if err != nil { + t.Fatalf("NewRequest() error = %v", err) + } + cfg := &config.Config{ + CodexHeaderDefaults: config.CodexHeaderDefaults{ + UserAgent: "config-ua", + BetaFeatures: "config-beta", + }, + } + auth := &cliproxyauth.Auth{ + Provider: "codex", + Metadata: map[string]any{"email": "user@example.com"}, + } + req = req.WithContext(contextWithGinHeaders(map[string]string{ + "User-Agent": "client-ua", + })) + + applyCodexHeaders(req, auth, "oauth-token", true, cfg) + + if got := req.Header.Get("User-Agent"); got != "config-ua" { + t.Fatalf("User-Agent = %s, want %s", got, "config-ua") + } + if got := req.Header.Get("x-codex-beta-features"); got != "" { + t.Fatalf("x-codex-beta-features = %q, want empty", got) + } +} + +func contextWithGinHeaders(headers map[string]string) context.Context { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil) + ginCtx.Request.Header = make(http.Header, len(headers)) + for key, value := range headers { + ginCtx.Request.Header.Set(key, value) + } + return context.WithValue(context.Background(), "gin", ginCtx) } From 163fe287ce0096c5e626e03ceba8cac2d1cdebc1 Mon Sep 17 00:00:00 2001 From: lang-911 Date: Wed, 11 Mar 2026 06:55:03 -0700 Subject: [PATCH 04/13] fix: codex header defaults example --- config.example.yaml | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 16be5c3698..43f063c4e6 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,6 +1,6 @@ # Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6). # Use "127.0.0.1" or "localhost" to restrict access to local machine only. -host: "" +host: '' # Server port port: 8317 @@ -8,8 +8,8 @@ port: 8317 # TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key. tls: enable: false - cert: "" - key: "" + cert: '' + key: '' # Management API settings remote-management: @@ -20,22 +20,22 @@ remote-management: # Management key. If a plaintext value is provided here, it will be hashed on startup. # All management requests (even from localhost) require this key. # Leave empty to disable the Management API entirely (404 for all /v0/management routes). - secret-key: "" + secret-key: '' # Disable the bundled management control panel asset download and HTTP route when true. disable-control-panel: false # GitHub repository for the management control panel. Accepts a repository URL or releases API URL. - panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center" + panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center' # Authentication directory (supports ~ for home directory) -auth-dir: "~/.cli-proxy-api" +auth-dir: '~/.cli-proxy-api' # API keys for authentication api-keys: - - "your-api-key-1" - - "your-api-key-2" - - "your-api-key-3" + - 'your-api-key-1' + - 'your-api-key-2' + - 'your-api-key-3' # Enable debug logging debug: false @@ -43,7 +43,7 @@ debug: false # Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety. pprof: enable: false - addr: "127.0.0.1:8316" + addr: '127.0.0.1:8316' # When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency. commercial-mode: false @@ -63,7 +63,7 @@ error-logs-max-files: 10 usage-statistics-enabled: false # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/ -proxy-url: "" +proxy-url: '' # When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name). force-model-prefix: false @@ -89,7 +89,7 @@ quota-exceeded: # Routing strategy for selecting credentials when multiple match. routing: - strategy: "round-robin" # round-robin (default), fill-first + strategy: 'round-robin' # round-robin (default), fill-first # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false @@ -178,8 +178,8 @@ nonstream-keepalive-interval: 0 # does not send the header. `user-agent` applies to HTTP and websocket requests; # `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries. # codex-header-defaults: -# user-agent: "my-codex-client/1.0" -# beta-features: "feature-a,feature-b" +# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0" +# beta-features: "multi_agent" # OpenAI compatibility providers # openai-compatibility: From 2b79d7f22fcf7d797e11375c31d09aa8fcf352b1 Mon Sep 17 00:00:00 2001 From: lang-911 Date: Wed, 11 Mar 2026 06:59:26 -0700 Subject: [PATCH 05/13] fix: restore double quotes style in config.example.yaml for consistency and readability --- config.example.yaml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 43f063c4e6..4297eb15ac 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,6 +1,6 @@ # Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6). # Use "127.0.0.1" or "localhost" to restrict access to local machine only. -host: '' +host: "" # Server port port: 8317 @@ -8,8 +8,8 @@ port: 8317 # TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key. tls: enable: false - cert: '' - key: '' + cert: "" + key: "" # Management API settings remote-management: @@ -20,22 +20,22 @@ remote-management: # Management key. If a plaintext value is provided here, it will be hashed on startup. # All management requests (even from localhost) require this key. # Leave empty to disable the Management API entirely (404 for all /v0/management routes). - secret-key: '' + secret-key: "" # Disable the bundled management control panel asset download and HTTP route when true. disable-control-panel: false # GitHub repository for the management control panel. Accepts a repository URL or releases API URL. - panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center' + panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center" # Authentication directory (supports ~ for home directory) -auth-dir: '~/.cli-proxy-api' +auth-dir: "~/.cli-proxy-api" # API keys for authentication api-keys: - - 'your-api-key-1' - - 'your-api-key-2' - - 'your-api-key-3' + - "your-api-key-1" + - "your-api-key-2" + - "your-api-key-3" # Enable debug logging debug: false @@ -43,7 +43,7 @@ debug: false # Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety. pprof: enable: false - addr: '127.0.0.1:8316' + addr: "127.0.0.1:8316" # When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency. commercial-mode: false @@ -63,7 +63,7 @@ error-logs-max-files: 10 usage-statistics-enabled: false # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/ -proxy-url: '' +proxy-url: "" # When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name). force-model-prefix: false @@ -89,7 +89,7 @@ quota-exceeded: # Routing strategy for selecting credentials when multiple match. routing: - strategy: 'round-robin' # round-robin (default), fill-first + strategy: "round-robin" # round-robin (default), fill-first # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false From 861537c9bd77fb3016578b78ad1216ad83741109 Mon Sep 17 00:00:00 2001 From: Aikins Laryea Date: Thu, 12 Mar 2026 00:00:38 +0000 Subject: [PATCH 06/13] fix: backfill empty functionResponse.name from preceding functionCall when Amp or Claude Code sends functionResponse with an empty name in Gemini conversation history, the Gemini API rejects the request with 400 "Name cannot be empty". this fix backfills empty names from the corresponding preceding functionCall parts using positional matching. covers all three Gemini translator paths: - gemini/gemini (direct API key) - antigravity/gemini (OAuth) - gemini-cli/gemini (Gemini CLI) also switches fixCLIToolResponse pending group matching from LIFO to FIFO to correctly handle multiple sequential tool call groups. fixes #1903 --- .../gemini/antigravity_gemini_request.go | 81 +++--- .../gemini/antigravity_gemini_request_test.go | 254 ++++++++++++++++++ .../gemini/gemini-cli_gemini_request.go | 68 +++-- .../gemini/gemini/gemini_gemini_request.go | 67 +++++ .../gemini/gemini_gemini_request_test.go | 193 +++++++++++++ 5 files changed, 604 insertions(+), 59 deletions(-) create mode 100644 internal/translator/gemini/gemini/gemini_gemini_request_test.go diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request.go b/internal/translator/antigravity/gemini/antigravity_gemini_request.go index 1d04474069..2c8ff402c7 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request.go @@ -138,20 +138,31 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ // FunctionCallGroup represents a group of function calls and their responses type FunctionCallGroup struct { ResponsesNeeded int + CallNames []string // ordered function call names for backfilling empty response names } // parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string. // Falls back to a minimal "functionResponse" object when parsing fails. -func parseFunctionResponseRaw(response gjson.Result) string { +// fallbackName is used when the response's own name is empty. +func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string { if response.IsObject() && gjson.Valid(response.Raw) { - return response.Raw + raw := response.Raw + name := response.Get("functionResponse.name").String() + if strings.TrimSpace(name) == "" && fallbackName != "" { + raw, _ = sjson.Set(raw, "functionResponse.name", fallbackName) + } + return raw } log.Debugf("parse function response failed, using fallback") funcResp := response.Get("functionResponse") if funcResp.Exists() { fr := `{"functionResponse":{"name":"","response":{"result":""}}}` - fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String()) + name := funcResp.Get("name").String() + if strings.TrimSpace(name) == "" { + name = fallbackName + } + fr, _ = sjson.Set(fr, "functionResponse.name", name) fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String()) if id := funcResp.Get("id").String(); id != "" { fr, _ = sjson.Set(fr, "functionResponse.id", id) @@ -159,7 +170,12 @@ func parseFunctionResponseRaw(response gjson.Result) string { return fr } - fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}` + useName := fallbackName + if useName == "" { + useName = "unknown" + } + fr := `{"functionResponse":{"name":"","response":{"result":""}}}` + fr, _ = sjson.Set(fr, "functionResponse.name", useName) fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String()) return fr } @@ -211,30 +227,26 @@ func fixCLIToolResponse(input string) (string, error) { if len(responsePartsInThisContent) > 0 { collectedResponses = append(collectedResponses, responsePartsInThisContent...) - // Check if any pending groups can be satisfied - for i := len(pendingGroups) - 1; i >= 0; i-- { - group := pendingGroups[i] - if len(collectedResponses) >= group.ResponsesNeeded { - // Take the needed responses for this group - groupResponses := collectedResponses[:group.ResponsesNeeded] - collectedResponses = collectedResponses[group.ResponsesNeeded:] - - // Create merged function response content - functionResponseContent := `{"parts":[],"role":"function"}` - for _, response := range groupResponses { - partRaw := parseFunctionResponseRaw(response) - if partRaw != "" { - functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw) - } - } - - if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { - contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) + // Check if pending groups can be satisfied (FIFO: oldest group first) + for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded { + group := pendingGroups[0] + pendingGroups = pendingGroups[1:] + + // Take the needed responses for this group + groupResponses := collectedResponses[:group.ResponsesNeeded] + collectedResponses = collectedResponses[group.ResponsesNeeded:] + + // Create merged function response content + functionResponseContent := `{"parts":[],"role":"function"}` + for ri, response := range groupResponses { + partRaw := parseFunctionResponseRaw(response, group.CallNames[ri]) + if partRaw != "" { + functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw) } + } - // Remove this group as it's been satisfied - pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...) - break + if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { + contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) } } @@ -243,15 +255,15 @@ func fixCLIToolResponse(input string) (string, error) { // If this is a model with function calls, create a new group if role == "model" { - functionCallsCount := 0 + var callNames []string parts.ForEach(func(_, part gjson.Result) bool { if part.Get("functionCall").Exists() { - functionCallsCount++ + callNames = append(callNames, part.Get("functionCall.name").String()) } return true }) - if functionCallsCount > 0 { + if len(callNames) > 0 { // Add the model content if !value.IsObject() { log.Warnf("failed to parse model content") @@ -261,7 +273,8 @@ func fixCLIToolResponse(input string) (string, error) { // Create a new group for tracking responses group := &FunctionCallGroup{ - ResponsesNeeded: functionCallsCount, + ResponsesNeeded: len(callNames), + CallNames: callNames, } pendingGroups = append(pendingGroups, group) } else { @@ -291,8 +304,12 @@ func fixCLIToolResponse(input string) (string, error) { collectedResponses = collectedResponses[group.ResponsesNeeded:] functionResponseContent := `{"parts":[],"role":"function"}` - for _, response := range groupResponses { - partRaw := parseFunctionResponseRaw(response) + for ri, response := range groupResponses { + fallbackName := "" + if ri < len(group.CallNames) { + fallbackName = group.CallNames[ri] + } + partRaw := parseFunctionResponseRaw(response, fallbackName) if partRaw != "" { functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw) } diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go b/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go index da581d1a3c..7e9e3bba8b 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go @@ -171,3 +171,257 @@ func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) { t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String()) } } + +func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) { + // When the Amp client sends functionResponse with an empty name, + // fixCLIToolResponse should backfill it from the corresponding functionCall. + input := `{ + "model": "gemini-3-pro-preview", + "request": { + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}} + ] + }, + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "", "response": {"output": "file1.txt"}}} + ] + } + ] + } + }` + + result, err := fixCLIToolResponse(input) + if err != nil { + t.Fatalf("fixCLIToolResponse failed: %v", err) + } + + contents := gjson.Get(result, "request.contents").Array() + var funcContent gjson.Result + for _, c := range contents { + if c.Get("role").String() == "function" { + funcContent = c + break + } + } + if !funcContent.Exists() { + t.Fatal("function role content should exist in output") + } + + name := funcContent.Get("parts.0.functionResponse.name").String() + if name != "Bash" { + t.Errorf("Expected backfilled name 'Bash', got '%s'", name) + } +} + +func TestFixCLIToolResponse_BackfillsMultipleEmptyNames(t *testing.T) { + // Parallel function calls: both responses have empty names. + input := `{ + "model": "gemini-3-pro-preview", + "request": { + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Read", "args": {"path": "/a"}}}, + {"functionCall": {"name": "Grep", "args": {"pattern": "x"}}} + ] + }, + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "", "response": {"result": "content a"}}}, + {"functionResponse": {"name": "", "response": {"result": "match x"}}} + ] + } + ] + } + }` + + result, err := fixCLIToolResponse(input) + if err != nil { + t.Fatalf("fixCLIToolResponse failed: %v", err) + } + + contents := gjson.Get(result, "request.contents").Array() + var funcContent gjson.Result + for _, c := range contents { + if c.Get("role").String() == "function" { + funcContent = c + break + } + } + if !funcContent.Exists() { + t.Fatal("function role content should exist in output") + } + + parts := funcContent.Get("parts").Array() + if len(parts) != 2 { + t.Fatalf("Expected 2 function response parts, got %d", len(parts)) + } + + name0 := parts[0].Get("functionResponse.name").String() + name1 := parts[1].Get("functionResponse.name").String() + if name0 != "Read" { + t.Errorf("Expected first response name 'Read', got '%s'", name0) + } + if name1 != "Grep" { + t.Errorf("Expected second response name 'Grep', got '%s'", name1) + } +} + +func TestFixCLIToolResponse_PreservesExistingName(t *testing.T) { + // When functionResponse already has a valid name, it should be preserved. + input := `{ + "model": "gemini-3-pro-preview", + "request": { + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Bash", "args": {}}} + ] + }, + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "Bash", "response": {"result": "ok"}}} + ] + } + ] + } + }` + + result, err := fixCLIToolResponse(input) + if err != nil { + t.Fatalf("fixCLIToolResponse failed: %v", err) + } + + contents := gjson.Get(result, "request.contents").Array() + var funcContent gjson.Result + for _, c := range contents { + if c.Get("role").String() == "function" { + funcContent = c + break + } + } + if !funcContent.Exists() { + t.Fatal("function role content should exist in output") + } + + name := funcContent.Get("parts.0.functionResponse.name").String() + if name != "Bash" { + t.Errorf("Expected preserved name 'Bash', got '%s'", name) + } +} + +func TestFixCLIToolResponse_MoreResponsesThanCalls(t *testing.T) { + // If there are more function responses than calls, unmatched extras are discarded by grouping. + input := `{ + "model": "gemini-3-pro-preview", + "request": { + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Bash", "args": {}}} + ] + }, + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "", "response": {"result": "ok"}}}, + {"functionResponse": {"name": "", "response": {"result": "extra"}}} + ] + } + ] + } + }` + + result, err := fixCLIToolResponse(input) + if err != nil { + t.Fatalf("fixCLIToolResponse failed: %v", err) + } + + contents := gjson.Get(result, "request.contents").Array() + var funcContent gjson.Result + for _, c := range contents { + if c.Get("role").String() == "function" { + funcContent = c + break + } + } + if !funcContent.Exists() { + t.Fatal("function role content should exist in output") + } + + // First response should be backfilled from the call + name0 := funcContent.Get("parts.0.functionResponse.name").String() + if name0 != "Bash" { + t.Errorf("Expected first response name 'Bash', got '%s'", name0) + } +} + +func TestFixCLIToolResponse_MultipleGroupsFIFO(t *testing.T) { + // Two sequential function call groups should be matched FIFO. + input := `{ + "model": "gemini-3-pro-preview", + "request": { + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Read", "args": {}}} + ] + }, + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "", "response": {"result": "file content"}}} + ] + }, + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Grep", "args": {}}} + ] + }, + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "", "response": {"result": "match"}}} + ] + } + ] + } + }` + + result, err := fixCLIToolResponse(input) + if err != nil { + t.Fatalf("fixCLIToolResponse failed: %v", err) + } + + contents := gjson.Get(result, "request.contents").Array() + var funcContents []gjson.Result + for _, c := range contents { + if c.Get("role").String() == "function" { + funcContents = append(funcContents, c) + } + } + if len(funcContents) != 2 { + t.Fatalf("Expected 2 function contents, got %d", len(funcContents)) + } + + name0 := funcContents[0].Get("parts.0.functionResponse.name").String() + name1 := funcContents[1].Get("parts.0.functionResponse.name").String() + if name0 != "Read" { + t.Errorf("Expected first group name 'Read', got '%s'", name0) + } + if name1 != "Grep" { + t.Errorf("Expected second group name 'Grep', got '%s'", name1) + } +} diff --git a/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go b/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go index 15ff8b983a..c60390886d 100644 --- a/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go +++ b/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go @@ -7,6 +7,7 @@ package gemini import ( "fmt" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" @@ -116,6 +117,17 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by // FunctionCallGroup represents a group of function calls and their responses type FunctionCallGroup struct { ResponsesNeeded int + CallNames []string // ordered function call names for backfilling empty response names +} + +// backfillFunctionResponseName ensures that a functionResponse JSON object has a non-empty name, +// falling back to fallbackName if the original is empty. +func backfillFunctionResponseName(raw string, fallbackName string) string { + name := gjson.Get(raw, "functionResponse.name").String() + if strings.TrimSpace(name) == "" && fallbackName != "" { + raw, _ = sjson.Set(raw, "functionResponse.name", fallbackName) + } + return raw } // fixCLIToolResponse performs sophisticated tool response format conversion and grouping. @@ -165,31 +177,28 @@ func fixCLIToolResponse(input string) (string, error) { if len(responsePartsInThisContent) > 0 { collectedResponses = append(collectedResponses, responsePartsInThisContent...) - // Check if any pending groups can be satisfied - for i := len(pendingGroups) - 1; i >= 0; i-- { - group := pendingGroups[i] - if len(collectedResponses) >= group.ResponsesNeeded { - // Take the needed responses for this group - groupResponses := collectedResponses[:group.ResponsesNeeded] - collectedResponses = collectedResponses[group.ResponsesNeeded:] + // Check if pending groups can be satisfied (FIFO: oldest group first) + for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded { + group := pendingGroups[0] + pendingGroups = pendingGroups[1:] - // Create merged function response content - functionResponseContent := `{"parts":[],"role":"function"}` - for _, response := range groupResponses { - if !response.IsObject() { - log.Warnf("failed to parse function response") - continue - } - functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw) - } + // Take the needed responses for this group + groupResponses := collectedResponses[:group.ResponsesNeeded] + collectedResponses = collectedResponses[group.ResponsesNeeded:] - if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { - contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) + // Create merged function response content + functionResponseContent := `{"parts":[],"role":"function"}` + for ri, response := range groupResponses { + if !response.IsObject() { + log.Warnf("failed to parse function response") + continue } + raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri]) + functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", raw) + } - // Remove this group as it's been satisfied - pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...) - break + if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { + contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) } } @@ -198,15 +207,15 @@ func fixCLIToolResponse(input string) (string, error) { // If this is a model with function calls, create a new group if role == "model" { - functionCallsCount := 0 + var callNames []string parts.ForEach(func(_, part gjson.Result) bool { if part.Get("functionCall").Exists() { - functionCallsCount++ + callNames = append(callNames, part.Get("functionCall.name").String()) } return true }) - if functionCallsCount > 0 { + if len(callNames) > 0 { // Add the model content if !value.IsObject() { log.Warnf("failed to parse model content") @@ -216,7 +225,8 @@ func fixCLIToolResponse(input string) (string, error) { // Create a new group for tracking responses group := &FunctionCallGroup{ - ResponsesNeeded: functionCallsCount, + ResponsesNeeded: len(callNames), + CallNames: callNames, } pendingGroups = append(pendingGroups, group) } else { @@ -246,12 +256,16 @@ func fixCLIToolResponse(input string) (string, error) { collectedResponses = collectedResponses[group.ResponsesNeeded:] functionResponseContent := `{"parts":[],"role":"function"}` - for _, response := range groupResponses { + for ri, response := range groupResponses { if !response.IsObject() { log.Warnf("failed to parse function response") continue } - functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw) + raw := response.Raw + if ri < len(group.CallNames) { + raw = backfillFunctionResponseName(raw, group.CallNames[ri]) + } + functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", raw) } if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { diff --git a/internal/translator/gemini/gemini/gemini_gemini_request.go b/internal/translator/gemini/gemini/gemini_gemini_request.go index 8024e9e329..abc176b2e2 100644 --- a/internal/translator/gemini/gemini/gemini_gemini_request.go +++ b/internal/translator/gemini/gemini/gemini_gemini_request.go @@ -5,9 +5,11 @@ package gemini import ( "fmt" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -95,6 +97,71 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte out = []byte(strJson) } + // Backfill empty functionResponse.name from the preceding functionCall.name. + // Amp may send function responses with empty names; the Gemini API rejects these. + out = backfillEmptyFunctionResponseNames(out) + out = common.AttachDefaultSafetySettings(out, "safetySettings") return out } + +// backfillEmptyFunctionResponseNames walks the contents array and for each +// model turn containing functionCall parts, records the call names in order. +// For the immediately following user/function turn containing functionResponse +// parts, any empty name is replaced with the corresponding call name. +func backfillEmptyFunctionResponseNames(data []byte) []byte { + contents := gjson.GetBytes(data, "contents") + if !contents.Exists() { + return data + } + + out := data + var pendingCallNames []string + + contents.ForEach(func(contentIdx, content gjson.Result) bool { + role := content.Get("role").String() + + // Collect functionCall names from model turns + if role == "model" { + var names []string + content.Get("parts").ForEach(func(_, part gjson.Result) bool { + if part.Get("functionCall").Exists() { + names = append(names, part.Get("functionCall.name").String()) + } + return true + }) + if len(names) > 0 { + pendingCallNames = names + } else { + pendingCallNames = nil + } + return true + } + + // Backfill empty functionResponse names from pending call names + if len(pendingCallNames) > 0 { + ri := 0 + content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool { + if part.Get("functionResponse").Exists() { + name := part.Get("functionResponse.name").String() + if strings.TrimSpace(name) == "" { + if ri < len(pendingCallNames) { + out, _ = sjson.SetBytes(out, + fmt.Sprintf("contents.%d.parts.%d.functionResponse.name", contentIdx.Int(), partIdx.Int()), + pendingCallNames[ri]) + } else { + log.Debugf("more function responses than calls at contents[%d], skipping name backfill", contentIdx.Int()) + } + } + ri++ + } + return true + }) + pendingCallNames = nil + } + + return true + }) + + return out +} diff --git a/internal/translator/gemini/gemini/gemini_gemini_request_test.go b/internal/translator/gemini/gemini/gemini_gemini_request_test.go new file mode 100644 index 0000000000..5eb88fa545 --- /dev/null +++ b/internal/translator/gemini/gemini/gemini_gemini_request_test.go @@ -0,0 +1,193 @@ +package gemini + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestBackfillEmptyFunctionResponseNames_Single(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "", "response": {"output": "file1.txt"}}} + ] + } + ] + }`) + + out := backfillEmptyFunctionResponseNames(input) + + name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String() + if name != "Bash" { + t.Errorf("Expected backfilled name 'Bash', got '%s'", name) + } +} + +func TestBackfillEmptyFunctionResponseNames_Parallel(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Read", "args": {"path": "/a"}}}, + {"functionCall": {"name": "Grep", "args": {"pattern": "x"}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "", "response": {"result": "content a"}}}, + {"functionResponse": {"name": "", "response": {"result": "match x"}}} + ] + } + ] + }`) + + out := backfillEmptyFunctionResponseNames(input) + + name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String() + name1 := gjson.GetBytes(out, "contents.1.parts.1.functionResponse.name").String() + if name0 != "Read" { + t.Errorf("Expected first name 'Read', got '%s'", name0) + } + if name1 != "Grep" { + t.Errorf("Expected second name 'Grep', got '%s'", name1) + } +} + +func TestBackfillEmptyFunctionResponseNames_PreservesExisting(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Bash", "args": {}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "Bash", "response": {"result": "ok"}}} + ] + } + ] + }`) + + out := backfillEmptyFunctionResponseNames(input) + + name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String() + if name != "Bash" { + t.Errorf("Expected preserved name 'Bash', got '%s'", name) + } +} + +func TestConvertGeminiRequestToGemini_BackfillsEmptyName(t *testing.T) { + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "", "response": {"output": "file1.txt"}}} + ] + } + ] + }`) + + out := ConvertGeminiRequestToGemini("", input, false) + + name := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String() + if name != "Bash" { + t.Errorf("Expected backfilled name 'Bash', got '%s'", name) + } +} + +func TestBackfillEmptyFunctionResponseNames_MoreResponsesThanCalls(t *testing.T) { + // Extra responses beyond the call count should not panic and should be left unchanged. + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Bash", "args": {}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "", "response": {"result": "ok"}}}, + {"functionResponse": {"name": "", "response": {"result": "extra"}}} + ] + } + ] + }`) + + out := backfillEmptyFunctionResponseNames(input) + + name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String() + if name0 != "Bash" { + t.Errorf("Expected first name 'Bash', got '%s'", name0) + } + // Second response has no matching call, should remain empty + name1 := gjson.GetBytes(out, "contents.1.parts.1.functionResponse.name").String() + if name1 != "" { + t.Errorf("Expected second name to remain empty, got '%s'", name1) + } +} + +func TestBackfillEmptyFunctionResponseNames_MultipleGroups(t *testing.T) { + // Two sequential call/response groups should each get correct names. + input := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Read", "args": {}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "", "response": {"result": "content"}}} + ] + }, + { + "role": "model", + "parts": [ + {"functionCall": {"name": "Grep", "args": {}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "", "response": {"result": "match"}}} + ] + } + ] + }`) + + out := backfillEmptyFunctionResponseNames(input) + + name0 := gjson.GetBytes(out, "contents.1.parts.0.functionResponse.name").String() + name1 := gjson.GetBytes(out, "contents.3.parts.0.functionResponse.name").String() + if name0 != "Read" { + t.Errorf("Expected first group name 'Read', got '%s'", name0) + } + if name1 != "Grep" { + t.Errorf("Expected second group name 'Grep', got '%s'", name1) + } +} From a6c3042e34c95f21633add04d064fb2a7626dd41 Mon Sep 17 00:00:00 2001 From: Aikins Laryea Date: Thu, 12 Mar 2026 00:12:43 +0000 Subject: [PATCH 07/13] refactor: remove redundant bounds checks per code review --- .../antigravity/gemini/antigravity_gemini_request.go | 6 +----- .../gemini-cli/gemini/gemini-cli_gemini_request.go | 5 +---- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request.go b/internal/translator/antigravity/gemini/antigravity_gemini_request.go index 2c8ff402c7..e5ce0c31bb 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request.go @@ -305,11 +305,7 @@ func fixCLIToolResponse(input string) (string, error) { functionResponseContent := `{"parts":[],"role":"function"}` for ri, response := range groupResponses { - fallbackName := "" - if ri < len(group.CallNames) { - fallbackName = group.CallNames[ri] - } - partRaw := parseFunctionResponseRaw(response, fallbackName) + partRaw := parseFunctionResponseRaw(response, group.CallNames[ri]) if partRaw != "" { functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw) } diff --git a/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go b/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go index c60390886d..a2af6f839b 100644 --- a/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go +++ b/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go @@ -261,10 +261,7 @@ func fixCLIToolResponse(input string) (string, error) { log.Warnf("failed to parse function response") continue } - raw := response.Raw - if ri < len(group.CallNames) { - raw = backfillFunctionResponseName(raw, group.CallNames[ri]) - } + raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri]) functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", raw) } From dea3e74d35a87eb9490dfbf9560d20691495262c Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Thu, 12 Mar 2026 09:24:45 +0800 Subject: [PATCH 08/13] feat(antigravity): refactor model handling and remove unused code --- internal/registry/model_definitions.go | 99 ++------ internal/registry/model_updater.go | 21 +- internal/registry/models/models.json | 139 +++++++++-- .../runtime/executor/antigravity_executor.go | 234 ------------------ .../antigravity_executor_models_cache_test.go | 90 ------- sdk/cliproxy/service.go | 59 +---- .../service_antigravity_backfill_test.go | 135 ---------- 7 files changed, 142 insertions(+), 635 deletions(-) delete mode 100644 internal/runtime/executor/antigravity_executor_models_cache_test.go delete mode 100644 sdk/cliproxy/service_antigravity_backfill_test.go diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index b7f5edb17b..14e2852ea7 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -3,32 +3,24 @@ package registry import ( - "sort" "strings" ) -// AntigravityModelConfig captures static antigravity model overrides, including -// Thinking budget limits and provider max completion tokens. -type AntigravityModelConfig struct { - Thinking *ThinkingSupport `json:"thinking,omitempty"` - MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` -} - // staticModelsJSON mirrors the top-level structure of models.json. type staticModelsJSON struct { - Claude []*ModelInfo `json:"claude"` - Gemini []*ModelInfo `json:"gemini"` - Vertex []*ModelInfo `json:"vertex"` - GeminiCLI []*ModelInfo `json:"gemini-cli"` - AIStudio []*ModelInfo `json:"aistudio"` - CodexFree []*ModelInfo `json:"codex-free"` - CodexTeam []*ModelInfo `json:"codex-team"` - CodexPlus []*ModelInfo `json:"codex-plus"` - CodexPro []*ModelInfo `json:"codex-pro"` - Qwen []*ModelInfo `json:"qwen"` - IFlow []*ModelInfo `json:"iflow"` - Kimi []*ModelInfo `json:"kimi"` - Antigravity map[string]*AntigravityModelConfig `json:"antigravity"` + Claude []*ModelInfo `json:"claude"` + Gemini []*ModelInfo `json:"gemini"` + Vertex []*ModelInfo `json:"vertex"` + GeminiCLI []*ModelInfo `json:"gemini-cli"` + AIStudio []*ModelInfo `json:"aistudio"` + CodexFree []*ModelInfo `json:"codex-free"` + CodexTeam []*ModelInfo `json:"codex-team"` + CodexPlus []*ModelInfo `json:"codex-plus"` + CodexPro []*ModelInfo `json:"codex-pro"` + Qwen []*ModelInfo `json:"qwen"` + IFlow []*ModelInfo `json:"iflow"` + Kimi []*ModelInfo `json:"kimi"` + Antigravity []*ModelInfo `json:"antigravity"` } // GetClaudeModels returns the standard Claude model definitions. @@ -91,33 +83,9 @@ func GetKimiModels() []*ModelInfo { return cloneModelInfos(getModels().Kimi) } -// GetAntigravityModelConfig returns static configuration for antigravity models. -// Keys use upstream model names returned by the Antigravity models endpoint. -func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { - data := getModels() - if len(data.Antigravity) == 0 { - return nil - } - out := make(map[string]*AntigravityModelConfig, len(data.Antigravity)) - for k, v := range data.Antigravity { - out[k] = cloneAntigravityModelConfig(v) - } - return out -} - -func cloneAntigravityModelConfig(cfg *AntigravityModelConfig) *AntigravityModelConfig { - if cfg == nil { - return nil - } - copyConfig := *cfg - if cfg.Thinking != nil { - copyThinking := *cfg.Thinking - if len(cfg.Thinking.Levels) > 0 { - copyThinking.Levels = append([]string(nil), cfg.Thinking.Levels...) - } - copyConfig.Thinking = ©Thinking - } - return ©Config +// GetAntigravityModels returns the standard Antigravity model definitions. +func GetAntigravityModels() []*ModelInfo { + return cloneModelInfos(getModels().Antigravity) } // cloneModelInfos returns a shallow copy of the slice with each element deep-cloned. @@ -145,7 +113,7 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo { // - qwen // - iflow // - kimi -// - antigravity (returns static overrides only) +// - antigravity func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { key := strings.ToLower(strings.TrimSpace(channel)) switch key { @@ -168,28 +136,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { case "kimi": return GetKimiModels() case "antigravity": - cfg := GetAntigravityModelConfig() - if len(cfg) == 0 { - return nil - } - models := make([]*ModelInfo, 0, len(cfg)) - for modelID, entry := range cfg { - if modelID == "" || entry == nil { - continue - } - models = append(models, &ModelInfo{ - ID: modelID, - Object: "model", - OwnedBy: "antigravity", - Type: "antigravity", - Thinking: entry.Thinking, - MaxCompletionTokens: entry.MaxCompletionTokens, - }) - } - sort.Slice(models, func(i, j int) bool { - return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID) - }) - return models + return GetAntigravityModels() default: return nil } @@ -213,6 +160,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { data.Qwen, data.IFlow, data.Kimi, + data.Antigravity, } for _, models := range allModels { for _, m := range models { @@ -222,14 +170,5 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { } } - // Check Antigravity static config - if cfg := cloneAntigravityModelConfig(data.Antigravity[modelID]); cfg != nil { - return &ModelInfo{ - ID: modelID, - Thinking: cfg.Thinking, - MaxCompletionTokens: cfg.MaxCompletionTokens, - } - } - return nil } diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 84c9d6aa63..8775ca3598 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -145,6 +145,7 @@ func validateModelsCatalog(data *staticModelsJSON) error { {name: "qwen", models: data.Qwen}, {name: "iflow", models: data.IFlow}, {name: "kimi", models: data.Kimi}, + {name: "antigravity", models: data.Antigravity}, } for _, section := range requiredSections { @@ -152,9 +153,6 @@ func validateModelsCatalog(data *staticModelsJSON) error { return err } } - if err := validateAntigravitySection(data.Antigravity); err != nil { - return err - } return nil } @@ -179,20 +177,3 @@ func validateModelSection(section string, models []*ModelInfo) error { } return nil } - -func validateAntigravitySection(configs map[string]*AntigravityModelConfig) error { - if len(configs) == 0 { - return fmt.Errorf("antigravity section is empty") - } - - for modelID, cfg := range configs { - trimmedID := strings.TrimSpace(modelID) - if trimmedID == "" { - return fmt.Errorf("antigravity contains empty model id") - } - if cfg == nil { - return fmt.Errorf("antigravity[%q] is null", trimmedID) - } - } - return nil -} diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index 5f919f9f6c..545b476c9a 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -2481,40 +2481,83 @@ } } ], - "antigravity": { - "claude-opus-4-6-thinking": { + "antigravity": [ + { + "id": "claude-opus-4-6-thinking", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Claude Opus 4.6 (Thinking)", + "name": "claude-opus-4-6-thinking", + "description": "Claude Opus 4.6 (Thinking)", + "context_length": 200000, + "max_completion_tokens": 64000, "thinking": { "min": 1024, "max": 64000, "zero_allowed": true, "dynamic_allowed": true - }, - "max_completion_tokens": 64000 + } }, - "claude-sonnet-4-6": { + { + "id": "claude-sonnet-4-6", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Claude Sonnet 4.6 (Thinking)", + "name": "claude-sonnet-4-6", + "description": "Claude Sonnet 4.6 (Thinking)", + "context_length": 200000, + "max_completion_tokens": 64000, "thinking": { "min": 1024, "max": 64000, "zero_allowed": true, "dynamic_allowed": true - }, - "max_completion_tokens": 64000 + } }, - "gemini-2.5-flash": { + { + "id": "gemini-2.5-flash", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 2.5 Flash", + "name": "gemini-2.5-flash", + "description": "Gemini 2.5 Flash", + "context_length": 1048576, + "max_completion_tokens": 65535, "thinking": { "max": 24576, "zero_allowed": true, "dynamic_allowed": true } }, - "gemini-2.5-flash-lite": { + { + "id": "gemini-2.5-flash-lite", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 2.5 Flash Lite", + "name": "gemini-2.5-flash-lite", + "description": "Gemini 2.5 Flash Lite", + "context_length": 1048576, + "max_completion_tokens": 65535, "thinking": { "max": 24576, "zero_allowed": true, "dynamic_allowed": true } }, - "gemini-3-flash": { + { + "id": "gemini-3-flash", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 3 Flash", + "name": "gemini-3-flash", + "description": "Gemini 3 Flash", + "context_length": 1048576, + "max_completion_tokens": 65536, "thinking": { "min": 128, "max": 32768, @@ -2527,7 +2570,16 @@ ] } }, - "gemini-3-pro-high": { + { + "id": "gemini-3-pro-high", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 3 Pro (High)", + "name": "gemini-3-pro-high", + "description": "Gemini 3 Pro (High)", + "context_length": 1048576, + "max_completion_tokens": 65535, "thinking": { "min": 128, "max": 32768, @@ -2538,7 +2590,16 @@ ] } }, - "gemini-3-pro-low": { + { + "id": "gemini-3-pro-low", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 3 Pro (Low)", + "name": "gemini-3-pro-low", + "description": "Gemini 3 Pro (Low)", + "context_length": 1048576, + "max_completion_tokens": 65535, "thinking": { "min": 128, "max": 32768, @@ -2549,7 +2610,14 @@ ] } }, - "gemini-3.1-flash-image": { + { + "id": "gemini-3.1-flash-image", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 3.1 Flash Image", + "name": "gemini-3.1-flash-image", + "description": "Gemini 3.1 Flash Image", "thinking": { "min": 128, "max": 32768, @@ -2560,7 +2628,14 @@ ] } }, - "gemini-3.1-flash-lite-preview": { + { + "id": "gemini-3.1-flash-lite-preview", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 3.1 Flash Lite Preview", + "name": "gemini-3.1-flash-lite-preview", + "description": "Gemini 3.1 Flash Lite Preview", "thinking": { "min": 128, "max": 32768, @@ -2571,7 +2646,16 @@ ] } }, - "gemini-3.1-pro-high": { + { + "id": "gemini-3.1-pro-high", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 3.1 Pro (High)", + "name": "gemini-3.1-pro-high", + "description": "Gemini 3.1 Pro (High)", + "context_length": 1048576, + "max_completion_tokens": 65535, "thinking": { "min": 128, "max": 32768, @@ -2582,7 +2666,16 @@ ] } }, - "gemini-3.1-pro-low": { + { + "id": "gemini-3.1-pro-low", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "Gemini 3.1 Pro (Low)", + "name": "gemini-3.1-pro-low", + "description": "Gemini 3.1 Pro (Low)", + "context_length": 1048576, + "max_completion_tokens": 65535, "thinking": { "min": 128, "max": 32768, @@ -2593,6 +2686,16 @@ ] } }, - "gpt-oss-120b-medium": {} - } + { + "id": "gpt-oss-120b-medium", + "object": "model", + "owned_by": "antigravity", + "type": "antigravity", + "display_name": "GPT-OSS 120B (Medium)", + "name": "gpt-oss-120b-medium", + "description": "GPT-OSS 120B (Medium)", + "context_length": 114000, + "max_completion_tokens": 32768 + } + ] } \ No newline at end of file diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index f3a052bf0c..cda02d2cea 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -24,7 +24,6 @@ import ( "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" @@ -43,7 +42,6 @@ const ( antigravityCountTokensPath = "/v1internal:countTokens" antigravityStreamPath = "/v1internal:streamGenerateContent" antigravityGeneratePath = "/v1internal:generateContent" - antigravityModelsPath = "/v1internal:fetchAvailableModels" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientSecret = "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" defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64" @@ -55,78 +53,8 @@ const ( var ( randSource = rand.New(rand.NewSource(time.Now().UnixNano())) randSourceMutex sync.Mutex - // antigravityPrimaryModelsCache keeps the latest non-empty model list fetched - // from any antigravity auth. Empty fetches never overwrite this cache. - antigravityPrimaryModelsCache struct { - mu sync.RWMutex - models []*registry.ModelInfo - } ) -func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo { - if len(models) == 0 { - return nil - } - out := make([]*registry.ModelInfo, 0, len(models)) - for _, model := range models { - if model == nil || strings.TrimSpace(model.ID) == "" { - continue - } - out = append(out, cloneAntigravityModelInfo(model)) - } - if len(out) == 0 { - return nil - } - return out -} - -func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo { - if model == nil { - return nil - } - clone := *model - if len(model.SupportedGenerationMethods) > 0 { - clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...) - } - if len(model.SupportedParameters) > 0 { - clone.SupportedParameters = append([]string(nil), model.SupportedParameters...) - } - if model.Thinking != nil { - thinkingClone := *model.Thinking - if len(model.Thinking.Levels) > 0 { - thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...) - } - clone.Thinking = &thinkingClone - } - return &clone -} - -func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool { - cloned := cloneAntigravityModels(models) - if len(cloned) == 0 { - return false - } - antigravityPrimaryModelsCache.mu.Lock() - antigravityPrimaryModelsCache.models = cloned - antigravityPrimaryModelsCache.mu.Unlock() - return true -} - -func loadAntigravityPrimaryModels() []*registry.ModelInfo { - antigravityPrimaryModelsCache.mu.RLock() - cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models) - antigravityPrimaryModelsCache.mu.RUnlock() - return cloned -} - -func fallbackAntigravityPrimaryModels() []*registry.ModelInfo { - models := loadAntigravityPrimaryModels() - if len(models) > 0 { - log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models)) - } - return models -} - // AntigravityExecutor proxies requests to the antigravity upstream. type AntigravityExecutor struct { cfg *config.Config @@ -1150,168 +1078,6 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut } } -// FetchAntigravityModels retrieves available models using the supplied auth. -func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo { - exec := &AntigravityExecutor{cfg: cfg} - token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth) - if errToken != nil || token == "" { - return fallbackAntigravityPrimaryModels() - } - if updatedAuth != nil { - auth = updatedAuth - } - - baseURLs := antigravityBaseURLFallbackOrder(auth) - httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0) - - for idx, baseURL := range baseURLs { - modelsURL := baseURL + antigravityModelsPath - - var payload []byte - if auth != nil && auth.Metadata != nil { - if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" { - payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid))) - } - } - if len(payload) == 0 { - payload = []byte(`{}`) - } - - httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload)) - if errReq != nil { - return fallbackAntigravityPrimaryModels() - } - httpReq.Close = true - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+token) - httpReq.Header.Set("User-Agent", resolveUserAgent(auth)) - if host := resolveHost(baseURL); host != "" { - httpReq.Host = host - } - - httpResp, errDo := httpClient.Do(httpReq) - if errDo != nil { - if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { - return fallbackAntigravityPrimaryModels() - } - if idx+1 < len(baseURLs) { - log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) - continue - } - return fallbackAntigravityPrimaryModels() - } - - bodyBytes, errRead := io.ReadAll(httpResp.Body) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("antigravity executor: close response body error: %v", errClose) - } - if errRead != nil { - if idx+1 < len(baseURLs) { - log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) - continue - } - return fallbackAntigravityPrimaryModels() - } - if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { - if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) { - log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) - continue - } - if idx+1 < len(baseURLs) { - log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1]) - continue - } - return fallbackAntigravityPrimaryModels() - } - - result := gjson.GetBytes(bodyBytes, "models") - if !result.Exists() { - if idx+1 < len(baseURLs) { - log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) - continue - } - return fallbackAntigravityPrimaryModels() - } - - now := time.Now().Unix() - modelConfig := registry.GetAntigravityModelConfig() - models := make([]*registry.ModelInfo, 0, len(result.Map())) - for originalName, modelData := range result.Map() { - modelID := strings.TrimSpace(originalName) - if modelID == "" { - continue - } - switch modelID { - case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro": - continue - } - modelCfg := modelConfig[modelID] - - // Extract displayName from upstream response, fallback to modelID - displayName := modelData.Get("displayName").String() - if displayName == "" { - displayName = modelID - } - - modelInfo := ®istry.ModelInfo{ - ID: modelID, - Name: modelID, - Description: displayName, - DisplayName: displayName, - Version: modelID, - Object: "model", - Created: now, - OwnedBy: antigravityAuthType, - Type: antigravityAuthType, - } - - // Build input modalities from upstream capability flags. - inputModalities := []string{"TEXT"} - if modelData.Get("supportsImages").Bool() { - inputModalities = append(inputModalities, "IMAGE") - } - if modelData.Get("supportsVideo").Bool() { - inputModalities = append(inputModalities, "VIDEO") - } - modelInfo.SupportedInputModalities = inputModalities - modelInfo.SupportedOutputModalities = []string{"TEXT"} - - // Token limits from upstream. - if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 { - modelInfo.InputTokenLimit = int(maxTok) - } - if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 { - modelInfo.OutputTokenLimit = int(maxOut) - } - - // Supported generation methods (Gemini v1beta convention). - modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"} - - // Look up Thinking support from static config using upstream model name. - if modelCfg != nil { - if modelCfg.Thinking != nil { - modelInfo.Thinking = modelCfg.Thinking - } - if modelCfg.MaxCompletionTokens > 0 { - modelInfo.MaxCompletionTokens = modelCfg.MaxCompletionTokens - } - } - models = append(models, modelInfo) - } - if len(models) == 0 { - if idx+1 < len(baseURLs) { - log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) - continue - } - log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list") - return fallbackAntigravityPrimaryModels() - } - storeAntigravityPrimaryModels(models) - return models - } - return fallbackAntigravityPrimaryModels() -} - func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) { if auth == nil { return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"} diff --git a/internal/runtime/executor/antigravity_executor_models_cache_test.go b/internal/runtime/executor/antigravity_executor_models_cache_test.go deleted file mode 100644 index be49a7c1ac..0000000000 --- a/internal/runtime/executor/antigravity_executor_models_cache_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package executor - -import ( - "testing" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" -) - -func resetAntigravityPrimaryModelsCacheForTest() { - antigravityPrimaryModelsCache.mu.Lock() - antigravityPrimaryModelsCache.models = nil - antigravityPrimaryModelsCache.mu.Unlock() -} - -func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) { - resetAntigravityPrimaryModelsCacheForTest() - t.Cleanup(resetAntigravityPrimaryModelsCacheForTest) - - seed := []*registry.ModelInfo{ - {ID: "claude-sonnet-4-5"}, - {ID: "gemini-2.5-pro"}, - } - if updated := storeAntigravityPrimaryModels(seed); !updated { - t.Fatal("expected non-empty model list to update primary cache") - } - - if updated := storeAntigravityPrimaryModels(nil); updated { - t.Fatal("expected nil model list not to overwrite primary cache") - } - if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated { - t.Fatal("expected empty model list not to overwrite primary cache") - } - - got := loadAntigravityPrimaryModels() - if len(got) != 2 { - t.Fatalf("expected cached model count 2, got %d", len(got)) - } - if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" { - t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID) - } -} - -func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) { - resetAntigravityPrimaryModelsCacheForTest() - t.Cleanup(resetAntigravityPrimaryModelsCacheForTest) - - if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{ - ID: "gpt-5", - DisplayName: "GPT-5", - SupportedGenerationMethods: []string{"generateContent"}, - SupportedParameters: []string{"temperature"}, - Thinking: ®istry.ThinkingSupport{ - Levels: []string{"high"}, - }, - }}); !updated { - t.Fatal("expected model cache update") - } - - got := loadAntigravityPrimaryModels() - if len(got) != 1 { - t.Fatalf("expected one cached model, got %d", len(got)) - } - got[0].ID = "mutated-id" - if len(got[0].SupportedGenerationMethods) > 0 { - got[0].SupportedGenerationMethods[0] = "mutated-method" - } - if len(got[0].SupportedParameters) > 0 { - got[0].SupportedParameters[0] = "mutated-parameter" - } - if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 { - got[0].Thinking.Levels[0] = "mutated-level" - } - - again := loadAntigravityPrimaryModels() - if len(again) != 1 { - t.Fatalf("expected one cached model after mutation, got %d", len(again)) - } - if again[0].ID != "gpt-5" { - t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID) - } - if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" { - t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods) - } - if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" { - t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters) - } - if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" { - t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking) - } -} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 596db3dd8b..af31f86aad 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -282,8 +282,6 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A // IMPORTANT: Update coreManager FIRST, before model registration. // This ensures that configuration changes (proxy_url, prefix, etc.) take effect // immediately for API calls, rather than waiting for model registration to complete. - // Model registration may involve network calls (e.g., FetchAntigravityModels) that - // could timeout if the new proxy_url is unreachable. op := "register" var err error if existing, ok := s.coreManager.GetByID(auth.ID); ok { @@ -813,9 +811,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { models = registry.GetAIStudioModels() models = applyExcludedModels(models, excluded) case "antigravity": - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - models = executor.FetchAntigravityModels(ctx, a, s.cfg) - cancel() + models = registry.GetAntigravityModels() models = applyExcludedModels(models, excluded) case "claude": models = registry.GetClaudeModels() @@ -952,9 +948,6 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { key = strings.ToLower(strings.TrimSpace(a.Provider)) } GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix)) - if provider == "antigravity" { - s.backfillAntigravityModels(a, models) - } return } @@ -1099,56 +1092,6 @@ func (s *Service) oauthExcludedModels(provider, authKind string) []string { return cfg.OAuthExcludedModels[providerKey] } -func (s *Service) backfillAntigravityModels(source *coreauth.Auth, primaryModels []*ModelInfo) { - if s == nil || s.coreManager == nil || len(primaryModels) == 0 { - return - } - - sourceID := "" - if source != nil { - sourceID = strings.TrimSpace(source.ID) - } - - reg := registry.GetGlobalRegistry() - for _, candidate := range s.coreManager.List() { - if candidate == nil || candidate.Disabled { - continue - } - candidateID := strings.TrimSpace(candidate.ID) - if candidateID == "" || candidateID == sourceID { - continue - } - if !strings.EqualFold(strings.TrimSpace(candidate.Provider), "antigravity") { - continue - } - if len(reg.GetModelsForClient(candidateID)) > 0 { - continue - } - - authKind := strings.ToLower(strings.TrimSpace(candidate.Attributes["auth_kind"])) - if authKind == "" { - if kind, _ := candidate.AccountInfo(); strings.EqualFold(kind, "api_key") { - authKind = "apikey" - } - } - excluded := s.oauthExcludedModels("antigravity", authKind) - if candidate.Attributes != nil { - if val, ok := candidate.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" { - excluded = strings.Split(val, ",") - } - } - - models := applyExcludedModels(primaryModels, excluded) - models = applyOAuthModelAlias(s.cfg, "antigravity", authKind, models) - if len(models) == 0 { - continue - } - - reg.RegisterClient(candidateID, "antigravity", applyModelPrefixes(models, candidate.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix)) - log.Debugf("antigravity models backfilled for auth %s using primary model list", candidateID) - } -} - func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo { if len(models) == 0 || len(excluded) == 0 { return models diff --git a/sdk/cliproxy/service_antigravity_backfill_test.go b/sdk/cliproxy/service_antigravity_backfill_test.go deleted file mode 100644 index df087438ea..0000000000 --- a/sdk/cliproxy/service_antigravity_backfill_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package cliproxy - -import ( - "context" - "strings" - "testing" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" - coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" -) - -func TestBackfillAntigravityModels_RegistersMissingAuth(t *testing.T) { - source := &coreauth.Auth{ - ID: "ag-backfill-source", - Provider: "antigravity", - Status: coreauth.StatusActive, - Attributes: map[string]string{ - "auth_kind": "oauth", - }, - } - target := &coreauth.Auth{ - ID: "ag-backfill-target", - Provider: "antigravity", - Status: coreauth.StatusActive, - Attributes: map[string]string{ - "auth_kind": "oauth", - }, - } - - manager := coreauth.NewManager(nil, nil, nil) - if _, err := manager.Register(context.Background(), source); err != nil { - t.Fatalf("register source auth: %v", err) - } - if _, err := manager.Register(context.Background(), target); err != nil { - t.Fatalf("register target auth: %v", err) - } - - service := &Service{ - cfg: &config.Config{}, - coreManager: manager, - } - - reg := registry.GetGlobalRegistry() - reg.UnregisterClient(source.ID) - reg.UnregisterClient(target.ID) - t.Cleanup(func() { - reg.UnregisterClient(source.ID) - reg.UnregisterClient(target.ID) - }) - - primary := []*ModelInfo{ - {ID: "claude-sonnet-4-5"}, - {ID: "gemini-2.5-pro"}, - } - reg.RegisterClient(source.ID, "antigravity", primary) - - service.backfillAntigravityModels(source, primary) - - got := reg.GetModelsForClient(target.ID) - if len(got) != 2 { - t.Fatalf("expected target auth to be backfilled with 2 models, got %d", len(got)) - } - - ids := make(map[string]struct{}, len(got)) - for _, model := range got { - if model == nil { - continue - } - ids[strings.ToLower(strings.TrimSpace(model.ID))] = struct{}{} - } - if _, ok := ids["claude-sonnet-4-5"]; !ok { - t.Fatal("expected backfilled model claude-sonnet-4-5") - } - if _, ok := ids["gemini-2.5-pro"]; !ok { - t.Fatal("expected backfilled model gemini-2.5-pro") - } -} - -func TestBackfillAntigravityModels_RespectsExcludedModels(t *testing.T) { - source := &coreauth.Auth{ - ID: "ag-backfill-source-excluded", - Provider: "antigravity", - Status: coreauth.StatusActive, - Attributes: map[string]string{ - "auth_kind": "oauth", - }, - } - target := &coreauth.Auth{ - ID: "ag-backfill-target-excluded", - Provider: "antigravity", - Status: coreauth.StatusActive, - Attributes: map[string]string{ - "auth_kind": "oauth", - "excluded_models": "gemini-2.5-pro", - }, - } - - manager := coreauth.NewManager(nil, nil, nil) - if _, err := manager.Register(context.Background(), source); err != nil { - t.Fatalf("register source auth: %v", err) - } - if _, err := manager.Register(context.Background(), target); err != nil { - t.Fatalf("register target auth: %v", err) - } - - service := &Service{ - cfg: &config.Config{}, - coreManager: manager, - } - - reg := registry.GetGlobalRegistry() - reg.UnregisterClient(source.ID) - reg.UnregisterClient(target.ID) - t.Cleanup(func() { - reg.UnregisterClient(source.ID) - reg.UnregisterClient(target.ID) - }) - - primary := []*ModelInfo{ - {ID: "claude-sonnet-4-5"}, - {ID: "gemini-2.5-pro"}, - } - reg.RegisterClient(source.ID, "antigravity", primary) - - service.backfillAntigravityModels(source, primary) - - got := reg.GetModelsForClient(target.ID) - if len(got) != 1 { - t.Fatalf("expected 1 model after exclusion, got %d", len(got)) - } - if got[0] == nil || !strings.EqualFold(strings.TrimSpace(got[0].ID), "claude-sonnet-4-5") { - t.Fatalf("expected remaining model %q, got %+v", "claude-sonnet-4-5", got[0]) - } -} From ec24baf757dbd03ad29092a7c5e302aa010e927b Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Thu, 12 Mar 2026 10:21:09 +0800 Subject: [PATCH 09/13] feat(fetch_antigravity_models): add command to fetch and save Antigravity model list --- cmd/fetch_antigravity_models/main.go | 275 +++++++++++++++++++++++++++ 1 file changed, 275 insertions(+) create mode 100644 cmd/fetch_antigravity_models/main.go diff --git a/cmd/fetch_antigravity_models/main.go b/cmd/fetch_antigravity_models/main.go new file mode 100644 index 0000000000..0cf45d3b3b --- /dev/null +++ b/cmd/fetch_antigravity_models/main.go @@ -0,0 +1,275 @@ +// Command fetch_antigravity_models connects to the Antigravity API using the +// stored auth credentials and saves the dynamically fetched model list to a +// JSON file for inspection or offline use. +// +// Usage: +// +// go run ./cmd/fetch_antigravity_models [flags] +// +// Flags: +// +// --auths-dir Directory containing auth JSON files (default: "auths") +// --output Output JSON file path (default: "antigravity_models.json") +// --pretty Pretty-print the output JSON (default: true) +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +const ( + antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com" + antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" + antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" + antigravityModelsPath = "/v1internal:fetchAvailableModels" +) + +func init() { + logging.SetupBaseLogger() + log.SetLevel(log.InfoLevel) +} + +// modelOutput wraps the fetched model list with fetch metadata. +type modelOutput struct { + Models []modelEntry `json:"models"` +} + +// modelEntry contains only the fields we want to keep for static model definitions. +type modelEntry struct { + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + Name string `json:"name"` + Description string `json:"description"` + ContextLength int `json:"context_length,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` +} + +func main() { + var authsDir string + var outputPath string + var pretty bool + + flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files") + flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path") + flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON") + flag.Parse() + + // Resolve relative paths against the working directory. + wd, err := os.Getwd() + if err != nil { + fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err) + os.Exit(1) + } + if !filepath.IsAbs(authsDir) { + authsDir = filepath.Join(wd, authsDir) + } + if !filepath.IsAbs(outputPath) { + outputPath = filepath.Join(wd, outputPath) + } + + fmt.Printf("Scanning auth files in: %s\n", authsDir) + + // Load all auth records from the directory. + fileStore := sdkauth.NewFileTokenStore() + fileStore.SetBaseDir(authsDir) + + ctx := context.Background() + auths, err := fileStore.List(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err) + os.Exit(1) + } + if len(auths) == 0 { + fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir) + os.Exit(1) + } + + // Find the first enabled antigravity auth. + var chosen *coreauth.Auth + for _, a := range auths { + if a == nil || a.Disabled { + continue + } + if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") { + chosen = a + break + } + } + if chosen == nil { + fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir) + os.Exit(1) + } + + fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label) + + // Fetch models from the upstream Antigravity API. + fmt.Println("Fetching Antigravity model list from upstream...") + + fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + models := fetchModels(fetchCtx, chosen) + if len(models) == 0 { + fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)") + } else { + fmt.Printf("Fetched %d models.\n", len(models)) + } + + // Build the output payload. + out := modelOutput{ + Models: models, + } + + // Marshal to JSON. + var raw []byte + if pretty { + raw, err = json.MarshalIndent(out, "", " ") + } else { + raw, err = json.Marshal(out) + } + if err != nil { + fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err) + os.Exit(1) + } + + if err = os.WriteFile(outputPath, raw, 0o644); err != nil { + fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err) + os.Exit(1) + } + + fmt.Printf("Model list saved to: %s\n", outputPath) +} + +func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry { + accessToken := metaStringValue(auth.Metadata, "access_token") + if accessToken == "" { + fmt.Fprintln(os.Stderr, "error: no access token found in auth") + return nil + } + + baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily} + + for _, baseURL := range baseURLs { + modelsURL := baseURL + antigravityModelsPath + + var payload []byte + if auth != nil && auth.Metadata != nil { + if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" { + payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid))) + } + } + if len(payload) == 0 { + payload = []byte(`{}`) + } + + httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload))) + if errReq != nil { + continue + } + httpReq.Close = true + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64") + + httpClient := &http.Client{Timeout: 30 * time.Second} + if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil { + httpClient.Transport = transport + } + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + continue + } + + bodyBytes, errRead := io.ReadAll(httpResp.Body) + httpResp.Body.Close() + if errRead != nil { + continue + } + + if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + continue + } + + result := gjson.GetBytes(bodyBytes, "models") + if !result.Exists() { + continue + } + + var models []modelEntry + + for originalName, modelData := range result.Map() { + modelID := strings.TrimSpace(originalName) + if modelID == "" { + continue + } + // Skip internal/experimental models + switch modelID { + case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro": + continue + } + + displayName := modelData.Get("displayName").String() + if displayName == "" { + displayName = modelID + } + + entry := modelEntry{ + ID: modelID, + Object: "model", + OwnedBy: "antigravity", + Type: "antigravity", + DisplayName: displayName, + Name: modelID, + Description: displayName, + } + + if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 { + entry.ContextLength = int(maxTok) + } + if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 { + entry.MaxCompletionTokens = int(maxOut) + } + + models = append(models, entry) + } + + return models + } + + return nil +} + +func metaStringValue(m map[string]interface{}, key string) string { + if m == nil { + return "" + } + v, ok := m[key] + if !ok { + return "" + } + switch val := v.(type) { + case string: + return val + default: + return "" + } +} From dbd42a42b29beb1238fdfaa65ae0ef1a29b0d529 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Thu, 12 Mar 2026 10:32:04 +0800 Subject: [PATCH 10/13] fix(model_updater): clarify log message for model refresh failure --- internal/registry/model_updater.go | 2 +- internal/registry/models/models.json | 18 ------------------ 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 8775ca3598..36d2dd32ae 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -100,7 +100,7 @@ func tryRefreshModels(ctx context.Context) { log.Infof("models updated from %s", url) return } - log.Warn("models refresh failed from all URLs, using current data") + log.Warn("models refresh failed from all URLs, using local data") } func loadModelsFromBytes(data []byte, source string) error { diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index 545b476c9a..9a30478801 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -2628,24 +2628,6 @@ ] } }, - { - "id": "gemini-3.1-flash-lite-preview", - "object": "model", - "owned_by": "antigravity", - "type": "antigravity", - "display_name": "Gemini 3.1 Flash Lite Preview", - "name": "gemini-3.1-flash-lite-preview", - "description": "Gemini 3.1 Flash Lite Preview", - "thinking": { - "min": 128, - "max": 32768, - "dynamic_allowed": true, - "levels": [ - "minimal", - "high" - ] - } - }, { "id": "gemini-3.1-pro-high", "object": "model", From 0ac52da460ae2f8b2ed174d2db0105e338365a1a Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 12 Mar 2026 10:50:46 +0800 Subject: [PATCH 11/13] chore(ci): update model catalog fetch method in release workflow --- .github/workflows/release.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 30cdbeab93..3e65352366 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -17,7 +17,9 @@ jobs: with: fetch-depth: 0 - name: Refresh models catalog - run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json + run: | + git fetch --depth 1 https://github.com/router-for-me/models.git main + git show FETCH_HEAD:models.json > internal/registry/models/models.json - run: git fetch --force --tags - uses: actions/setup-go@v4 with: From 5484489406f3c5fc022402b9ba712b9e3ba06f8b Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 12 Mar 2026 11:19:24 +0800 Subject: [PATCH 12/13] chore(ci): update model catalog fetch method in workflows --- .github/workflows/docker-image.yml | 8 ++++++-- .github/workflows/pr-test-build.yml | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 4a9501c090..9c8c2858d7 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -16,7 +16,9 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Refresh models catalog - run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json + run: | + git fetch --depth 1 https://github.com/router-for-me/models.git main + git show FETCH_HEAD:models.json > internal/registry/models/models.json - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to DockerHub @@ -49,7 +51,9 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Refresh models catalog - run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json + run: | + git fetch --depth 1 https://github.com/router-for-me/models.git main + git show FETCH_HEAD:models.json > internal/registry/models/models.json - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to DockerHub diff --git a/.github/workflows/pr-test-build.yml b/.github/workflows/pr-test-build.yml index b24b1fcbc4..75f4c520a5 100644 --- a/.github/workflows/pr-test-build.yml +++ b/.github/workflows/pr-test-build.yml @@ -13,7 +13,9 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Refresh models catalog - run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json + run: | + git fetch --depth 1 https://github.com/router-for-me/models.git main + git show FETCH_HEAD:models.json > internal/registry/models/models.json - name: Set up Go uses: actions/setup-go@v5 with: From c3d5dbe96f00919cbed27a52dce4b9b51c2c6141 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Fri, 13 Mar 2026 10:56:39 +0800 Subject: [PATCH 13/13] feat(model_registry): enhance model registration and refresh mechanisms --- internal/registry/model_registry.go | 16 +- internal/registry/model_updater.go | 219 ++++++++++++++++++++++++++-- sdk/cliproxy/service.go | 100 ++++++++++++- 3 files changed, 312 insertions(+), 23 deletions(-) diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 8f56c43d01..74ad6acf18 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -187,6 +187,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) { } const defaultModelRegistryHookTimeout = 5 * time.Second +const modelQuotaExceededWindow = 5 * time.Minute func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) { hook := r.hook @@ -388,6 +389,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ reg.InfoByProvider[provider] = cloneModelInfo(model) } reg.LastUpdated = now + // Re-registering an existing client/model binding starts a fresh registry + // snapshot for that binding. Cooldown and suspension are transient + // scheduling state and must not survive this reconciliation step. if reg.QuotaExceededClients != nil { delete(reg.QuotaExceededClients, clientID) } @@ -781,7 +785,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) { models := make([]map[string]any, 0, len(r.models)) - quotaExpiredDuration := 5 * time.Minute var expiresAt time.Time for _, registration := range r.models { @@ -792,7 +795,7 @@ func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time. if quotaTime == nil { continue } - recoveryAt := quotaTime.Add(quotaExpiredDuration) + recoveryAt := quotaTime.Add(modelQuotaExceededWindow) if now.Before(recoveryAt) { expiredClients++ if expiresAt.IsZero() || recoveryAt.Before(expiresAt) { @@ -927,7 +930,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn return nil } - quotaExpiredDuration := 5 * time.Minute now := time.Now() result := make([]*ModelInfo, 0, len(providerModels)) @@ -949,7 +951,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider { continue } - if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { + if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow { expiredClients++ } } @@ -1003,12 +1005,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int { if registration, exists := r.models[modelID]; exists { now := time.Now() - quotaExpiredDuration := 5 * time.Minute // Count clients that have exceeded quota but haven't recovered yet expiredClients := 0 for _, quotaTime := range registration.QuotaExceededClients { - if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { + if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow { expiredClients++ } } @@ -1217,12 +1218,11 @@ func (r *ModelRegistry) CleanupExpiredQuotas() { defer r.mutex.Unlock() now := time.Now() - quotaExpiredDuration := 5 * time.Minute invalidated := false for modelID, registration := range r.models { for clientID, quotaTime := range registration.QuotaExceededClients { - if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration { + if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow { delete(registration.QuotaExceededClients, clientID) invalidated = true log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID) diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 36d2dd32ae..197f604492 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -15,7 +15,8 @@ import ( ) const ( - modelsFetchTimeout = 30 * time.Second + modelsFetchTimeout = 30 * time.Second + modelsRefreshInterval = 3 * time.Hour ) var modelsURLs = []string{ @@ -35,6 +36,34 @@ var modelsCatalogStore = &modelStore{} var updaterOnce sync.Once +// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes. +// changedProviders contains the provider names whose model definitions changed. +type ModelRefreshCallback func(changedProviders []string) + +var ( + refreshCallbackMu sync.Mutex + refreshCallback ModelRefreshCallback + pendingRefreshChanges []string +) + +// SetModelRefreshCallback registers a callback that is invoked when startup or +// periodic model refresh detects changes. Only one callback is supported; +// subsequent calls replace the previous callback. +func SetModelRefreshCallback(cb ModelRefreshCallback) { + refreshCallbackMu.Lock() + refreshCallback = cb + var pending []string + if cb != nil && len(pendingRefreshChanges) > 0 { + pending = append([]string(nil), pendingRefreshChanges...) + pendingRefreshChanges = nil + } + refreshCallbackMu.Unlock() + + if cb != nil && len(pending) > 0 { + cb(pending) + } +} + func init() { // Load embedded data as fallback on startup. if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil { @@ -42,23 +71,76 @@ func init() { } } -// StartModelsUpdater runs a one-time models refresh on startup. -// It blocks until the startup fetch attempt finishes so service initialization -// can wait for the refreshed catalog before registering auth-backed models. -// Safe to call multiple times; only one refresh will run. +// StartModelsUpdater starts a background updater that fetches models +// immediately on startup and then refreshes the model catalog every 3 hours. +// Safe to call multiple times; only one updater will run. func StartModelsUpdater(ctx context.Context) { updaterOnce.Do(func() { - runModelsUpdater(ctx) + go runModelsUpdater(ctx) }) } func runModelsUpdater(ctx context.Context) { - // Try network fetch once on startup, then stop. - // Periodic refresh is disabled - models are only refreshed at startup. - tryRefreshModels(ctx) + tryStartupRefresh(ctx) + periodicRefresh(ctx) +} + +func periodicRefresh(ctx context.Context) { + ticker := time.NewTicker(modelsRefreshInterval) + defer ticker.Stop() + log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + tryPeriodicRefresh(ctx) + } + } +} + +// tryPeriodicRefresh fetches models from remote, compares with the current +// catalog, and notifies the registered callback if any provider changed. +func tryPeriodicRefresh(ctx context.Context) { + tryRefreshModels(ctx, "periodic model refresh") +} + +// tryStartupRefresh fetches models from remote in the background during +// process startup. It uses the same change detection as periodic refresh so +// existing auth registrations can be updated after the callback is registered. +func tryStartupRefresh(ctx context.Context) { + tryRefreshModels(ctx, "startup model refresh") +} + +func tryRefreshModels(ctx context.Context, label string) { + oldData := getModels() + + parsed, url := fetchModelsFromRemote(ctx) + if parsed == nil { + log.Warnf("%s: fetch failed from all URLs, keeping current data", label) + return + } + + // Detect changes before updating store. + changed := detectChangedProviders(oldData, parsed) + + // Update store with new data regardless. + modelsCatalogStore.mu.Lock() + modelsCatalogStore.data = parsed + modelsCatalogStore.mu.Unlock() + + if len(changed) == 0 { + log.Infof("%s completed from %s, no changes detected", label, url) + return + } + + log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed) + notifyModelRefresh(changed) } -func tryRefreshModels(ctx context.Context) { +// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog +// along with the URL it was fetched from. Returns (nil, "") if all fetches fail. +func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) { client := &http.Client{Timeout: modelsFetchTimeout} for _, url := range modelsURLs { reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout) @@ -92,15 +174,126 @@ func tryRefreshModels(ctx context.Context) { continue } - if err := loadModelsFromBytes(data, url); err != nil { + var parsed staticModelsJSON + if err := json.Unmarshal(data, &parsed); err != nil { log.Warnf("models parse failed from %s: %v", url, err) continue } + if err := validateModelsCatalog(&parsed); err != nil { + log.Warnf("models validate failed from %s: %v", url, err) + continue + } + + return &parsed, url + } + return nil, "" +} - log.Infof("models updated from %s", url) +// detectChangedProviders compares two model catalogs and returns provider names +// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped +// under a single "codex" provider. +func detectChangedProviders(oldData, newData *staticModelsJSON) []string { + if oldData == nil || newData == nil { + return nil + } + + type section struct { + provider string + oldList []*ModelInfo + newList []*ModelInfo + } + + sections := []section{ + {"claude", oldData.Claude, newData.Claude}, + {"gemini", oldData.Gemini, newData.Gemini}, + {"vertex", oldData.Vertex, newData.Vertex}, + {"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI}, + {"aistudio", oldData.AIStudio, newData.AIStudio}, + {"codex", oldData.CodexFree, newData.CodexFree}, + {"codex", oldData.CodexTeam, newData.CodexTeam}, + {"codex", oldData.CodexPlus, newData.CodexPlus}, + {"codex", oldData.CodexPro, newData.CodexPro}, + {"qwen", oldData.Qwen, newData.Qwen}, + {"iflow", oldData.IFlow, newData.IFlow}, + {"kimi", oldData.Kimi, newData.Kimi}, + {"antigravity", oldData.Antigravity, newData.Antigravity}, + } + + seen := make(map[string]bool, len(sections)) + var changed []string + for _, s := range sections { + if seen[s.provider] { + continue + } + if modelSectionChanged(s.oldList, s.newList) { + changed = append(changed, s.provider) + seen[s.provider] = true + } + } + return changed +} + +// modelSectionChanged reports whether two model slices differ. +func modelSectionChanged(a, b []*ModelInfo) bool { + if len(a) != len(b) { + return true + } + if len(a) == 0 { + return false + } + aj, err1 := json.Marshal(a) + bj, err2 := json.Marshal(b) + if err1 != nil || err2 != nil { + return true + } + return string(aj) != string(bj) +} + +func notifyModelRefresh(changedProviders []string) { + if len(changedProviders) == 0 { return } - log.Warn("models refresh failed from all URLs, using local data") + + refreshCallbackMu.Lock() + cb := refreshCallback + if cb == nil { + pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders) + refreshCallbackMu.Unlock() + return + } + refreshCallbackMu.Unlock() + cb(changedProviders) +} + +func mergeProviderNames(existing, incoming []string) []string { + if len(incoming) == 0 { + return existing + } + seen := make(map[string]struct{}, len(existing)+len(incoming)) + merged := make([]string, 0, len(existing)+len(incoming)) + for _, provider := range existing { + name := strings.ToLower(strings.TrimSpace(provider)) + if name == "" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + merged = append(merged, name) + } + for _, provider := range incoming { + name := strings.ToLower(strings.TrimSpace(provider)) + if name == "" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + merged = append(merged, name) + } + return merged } func loadModelsFromBytes(data []byte, source string) error { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index af31f86aad..abe1deed5f 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -434,6 +434,17 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace } } +func (s *Service) registerResolvedModelsForAuth(a *coreauth.Auth, providerKey string, models []*ModelInfo) { + if a == nil || a.ID == "" { + return + } + if len(models) == 0 { + GlobalModelRegistry().UnregisterClient(a.ID) + return + } + GlobalModelRegistry().RegisterClient(a.ID, providerKey, models) +} + // rebindExecutors refreshes provider executors so they observe the latest configuration. func (s *Service) rebindExecutors() { if s == nil || s.coreManager == nil { @@ -541,6 +552,44 @@ func (s *Service) Run(ctx context.Context) error { s.hooks.OnBeforeStart(s.cfg) } + // Register callback for startup and periodic model catalog refresh. + // When remote model definitions change, re-register models for affected providers. + // This intentionally rebuilds per-auth model availability from the latest catalog + // snapshot instead of preserving prior registry suppression state. + registry.SetModelRefreshCallback(func(changedProviders []string) { + if s == nil || s.coreManager == nil || len(changedProviders) == 0 { + return + } + + providerSet := make(map[string]bool, len(changedProviders)) + for _, p := range changedProviders { + providerSet[strings.ToLower(strings.TrimSpace(p))] = true + } + + auths := s.coreManager.List() + refreshed := 0 + for _, item := range auths { + if item == nil || item.ID == "" { + continue + } + auth, ok := s.coreManager.GetByID(item.ID) + if !ok || auth == nil || auth.Disabled { + continue + } + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if !providerSet[provider] { + continue + } + if s.refreshModelRegistrationForAuth(auth) { + refreshed++ + } + } + + if refreshed > 0 { + log.Infof("re-registered models for %d auth(s) due to model catalog changes: %v", refreshed, changedProviders) + } + }) + s.serverErr = make(chan error, 1) go func() { if errStart := s.server.Start(); errStart != nil { @@ -926,7 +975,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { if providerKey == "" { providerKey = "openai-compatibility" } - GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix)) + s.registerResolvedModelsForAuth(a, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix)) } else { // Ensure stale registrations are cleared when model list becomes empty. GlobalModelRegistry().UnregisterClient(a.ID) @@ -947,13 +996,60 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { if key == "" { key = strings.ToLower(strings.TrimSpace(a.Provider)) } - GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix)) + s.registerResolvedModelsForAuth(a, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix)) return } GlobalModelRegistry().UnregisterClient(a.ID) } +// refreshModelRegistrationForAuth re-applies the latest model registration for +// one auth and reconciles any concurrent auth changes that race with the +// refresh. Callers are expected to pre-filter provider membership. +// +// Re-registration is deliberate: registry cooldown/suspension state is treated +// as part of the previous registration snapshot and is cleared when the auth is +// rebound to the refreshed model catalog. +func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool { + if s == nil || s.coreManager == nil || current == nil || current.ID == "" { + return false + } + + if !current.Disabled { + s.ensureExecutorsForAuth(current) + } + s.registerModelsForAuth(current) + + latest, ok := s.latestAuthForModelRegistration(current.ID) + if !ok || latest.Disabled { + GlobalModelRegistry().UnregisterClient(current.ID) + s.coreManager.RefreshSchedulerEntry(current.ID) + return false + } + + // Re-apply the latest auth snapshot so concurrent auth updates cannot leave + // stale model registrations behind. This may duplicate registration work when + // no auth fields changed, but keeps the refresh path simple and correct. + s.ensureExecutorsForAuth(latest) + s.registerModelsForAuth(latest) + s.coreManager.RefreshSchedulerEntry(current.ID) + return true +} + +// latestAuthForModelRegistration returns the latest auth snapshot regardless of +// provider membership. Callers use this after a registration attempt to restore +// whichever state currently owns the client ID in the global registry. +func (s *Service) latestAuthForModelRegistration(authID string) (*coreauth.Auth, bool) { + if s == nil || s.coreManager == nil || authID == "" { + return nil, false + } + auth, ok := s.coreManager.GetByID(authID) + if !ok || auth == nil || auth.ID == "" { + return nil, false + } + return auth, true +} + func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey { if auth == nil || s.cfg == nil { return nil