diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go index de546ea820..6c994c2cd0 100644 --- a/internal/api/handlers/management/api_tools.go +++ b/internal/api/handlers/management/api_tools.go @@ -425,10 +425,13 @@ func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth * } auth.Metadata["type"] = "antigravity" - if h != nil && h.authManager != nil { - auth.LastRefreshedAt = now - auth.UpdatedAt = now - _, _ = h.authManager.Update(ctx, auth) + if h != nil { + state, err := h.runtimeSnapshot() + if err == nil && state.authManager != nil { + auth.LastRefreshedAt = now + auth.UpdatedAt = now + _, _ = state.authManager.Update(ctx, auth) + } } return strings.TrimSpace(tokenResp.AccessToken), nil @@ -614,10 +617,14 @@ func tokenValueFromMetadata(metadata map[string]any) string { func (h *Handler) authByIndex(authIndex string) *coreauth.Auth { authIndex = strings.TrimSpace(authIndex) - if authIndex == "" || h == nil || h.authManager == nil { + if authIndex == "" || h == nil { + return nil + } + state, err := h.runtimeSnapshot() + if err != nil || state.authManager == nil { return nil } - auths := h.authManager.List() + auths := state.authManager.List() for _, auth := range auths { if auth == nil { continue @@ -637,9 +644,12 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { proxyCandidates = append(proxyCandidates, proxyStr) } } - if h != nil && h.cfg != nil { - if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" { - proxyCandidates = append(proxyCandidates, proxyStr) + if h != nil { + state, err := h.runtimeSnapshot() + if err == nil && state.cfg != nil { + if proxyStr := strings.TrimSpace(state.cfg.ProxyURL); proxyStr != "" { + proxyCandidates = append(proxyCandidates, proxyStr) + } } } diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 4d1ec44cf2..f9522a0593 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -221,17 +221,18 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) { } func (h *Handler) managementCallbackURL(path string) (string, error) { - if h == nil || h.cfg == nil || h.cfg.Port <= 0 { + snapshot, err := h.runtimeSnapshot() + if h == nil || err != nil || snapshot.cfg == nil || snapshot.cfg.Port <= 0 { return "", fmt.Errorf("server port is not configured") } if !strings.HasPrefix(path, "/") { path = "/" + path } scheme := "http" - if h.cfg.TLS.Enable { + if snapshot.cfg.TLS.Enable { scheme = "https" } - return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, h.cfg.Port, path), nil + return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, snapshot.cfg.Port, path), nil } func (h *Handler) ListAuthFiles(c *gin.Context) { @@ -239,11 +240,16 @@ func (h *Handler) ListAuthFiles(c *gin.Context) { c.JSON(500, gin.H{"error": "handler not initialized"}) return } - if h.authManager == nil { + snapshot, err := h.runtimeSnapshot() + if err != nil { + c.JSON(500, gin.H{"error": "failed to snapshot runtime state"}) + return + } + if snapshot.authManager == nil { h.listAuthFilesFromDisk(c) return } - auths := h.authManager.List() + auths := snapshot.authManager.List() files := make([]gin.H, 0, len(auths)) for _, auth := range auths { if entry := h.buildAuthFileEntry(auth); entry != nil { @@ -268,8 +274,8 @@ func (h *Handler) GetAuthFileModels(c *gin.Context) { // Try to find auth ID via authManager var authID string - if h.authManager != nil { - auths := h.authManager.List() + if snapshot, err := h.runtimeSnapshot(); err == nil && snapshot.authManager != nil { + auths := snapshot.authManager.List() for _, auth := range auths { if auth.FileName == name || auth.ID == name { authID = auth.ID @@ -308,7 +314,12 @@ func (h *Handler) GetAuthFileModels(c *gin.Context) { // List auth files from disk when the auth manager is unavailable. func (h *Handler) listAuthFilesFromDisk(c *gin.Context) { - entries, err := os.ReadDir(h.cfg.AuthDir) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(500, gin.H{"error": "failed to resolve auth dir"}) + return + } + entries, err := os.ReadDir(cfg.AuthDir) if err != nil { c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)}) return @@ -326,7 +337,7 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) { fileData := gin.H{"name": name, "size": info.Size(), "modtime": info.ModTime()} // Read file to get type field - full := filepath.Join(h.cfg.AuthDir, name) + full := filepath.Join(cfg.AuthDir, name) if data, errRead := os.ReadFile(full); errRead == nil { typeValue := gjson.GetBytes(data, "type").String() emailValue := gjson.GetBytes(data, "email").String() @@ -538,6 +549,51 @@ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool { return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true") } +func (h *Handler) authManagerSnapshot() (*coreauth.Manager, error) { + snapshot, err := h.runtimeSnapshot() + if err != nil { + return nil, err + } + return snapshot.authManager, nil +} + +func (h *Handler) authDirSnapshot() (string, error) { + cfg, err := h.configSnapshot() + if err != nil { + return "", err + } + if cfg == nil { + return "", nil + } + return cfg.AuthDir, nil +} + +func findAuthByName(manager *coreauth.Manager, name string, matchPath bool) *coreauth.Auth { + if manager == nil { + return nil + } + name = strings.TrimSpace(name) + if name == "" { + return nil + } + if auth, ok := manager.GetByID(name); ok { + return auth + } + auths := manager.List() + for _, auth := range auths { + if auth == nil { + continue + } + if strings.TrimSpace(auth.FileName) == name { + return auth + } + if matchPath && filepath.Base(strings.TrimSpace(authAttribute(auth, "path"))) == name { + return auth + } + } + return nil +} + // Download single auth file by name func (h *Handler) DownloadAuthFile(c *gin.Context) { name := c.Query("name") @@ -549,7 +605,12 @@ func (h *Handler) DownloadAuthFile(c *gin.Context) { c.JSON(400, gin.H{"error": "name must end with .json"}) return } - full := filepath.Join(h.cfg.AuthDir, name) + authDir, err := h.authDirSnapshot() + if err != nil || strings.TrimSpace(authDir) == "" { + c.JSON(500, gin.H{"error": "auth dir unavailable"}) + return + } + full := filepath.Join(authDir, name) data, err := os.ReadFile(full) if err != nil { if os.IsNotExist(err) { @@ -565,7 +626,13 @@ func (h *Handler) DownloadAuthFile(c *gin.Context) { // Upload auth file: multipart or raw JSON with ?name= func (h *Handler) UploadAuthFile(c *gin.Context) { - if h.authManager == nil { + manager, err := h.authManagerSnapshot() + if err != nil || manager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) + return + } + authDir, err := h.authDirSnapshot() + if err != nil || strings.TrimSpace(authDir) == "" { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) return } @@ -576,7 +643,7 @@ func (h *Handler) UploadAuthFile(c *gin.Context) { c.JSON(400, gin.H{"error": "file must be .json"}) return } - dst := filepath.Join(h.cfg.AuthDir, name) + dst := filepath.Join(authDir, name) if !filepath.IsAbs(dst) { if abs, errAbs := filepath.Abs(dst); errAbs == nil { dst = abs @@ -612,7 +679,7 @@ func (h *Handler) UploadAuthFile(c *gin.Context) { c.JSON(400, gin.H{"error": "failed to read body"}) return } - dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) + dst := filepath.Join(authDir, filepath.Base(name)) if !filepath.IsAbs(dst) { if abs, errAbs := filepath.Abs(dst); errAbs == nil { dst = abs @@ -631,13 +698,19 @@ func (h *Handler) UploadAuthFile(c *gin.Context) { // Delete auth files: single by name or all func (h *Handler) DeleteAuthFile(c *gin.Context) { - if h.authManager == nil { + manager, err := h.authManagerSnapshot() + if err != nil || manager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) + return + } + authDir, err := h.authDirSnapshot() + if err != nil || strings.TrimSpace(authDir) == "" { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) return } ctx := c.Request.Context() if all := c.Query("all"); all == "true" || all == "1" || all == "*" { - entries, err := os.ReadDir(h.cfg.AuthDir) + entries, err := os.ReadDir(authDir) if err != nil { c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)}) return @@ -651,7 +724,7 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) { if !strings.HasSuffix(strings.ToLower(name), ".json") { continue } - full := filepath.Join(h.cfg.AuthDir, name) + full := filepath.Join(authDir, name) if !filepath.IsAbs(full) { if abs, errAbs := filepath.Abs(full); errAbs == nil { full = abs @@ -675,9 +748,9 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) { return } - targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) + targetPath := filepath.Join(authDir, filepath.Base(name)) targetID := "" - if targetAuth := h.findAuthForDelete(name); targetAuth != nil { + if targetAuth := findAuthByName(manager, name, true); targetAuth != nil { targetID = strings.TrimSpace(targetAuth.ID) if path := strings.TrimSpace(authAttribute(targetAuth, "path")); path != "" { targetPath = path @@ -709,29 +782,11 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) { } func (h *Handler) findAuthForDelete(name string) *coreauth.Auth { - if h == nil || h.authManager == nil { - return nil - } - name = strings.TrimSpace(name) - if name == "" { + manager, err := h.authManagerSnapshot() + if h == nil || err != nil || manager == nil { return nil } - if auth, ok := h.authManager.GetByID(name); ok { - return auth - } - auths := h.authManager.List() - for _, auth := range auths { - if auth == nil { - continue - } - if strings.TrimSpace(auth.FileName) == name { - return auth - } - if filepath.Base(strings.TrimSpace(authAttribute(auth, "path"))) == name { - return auth - } - } - return nil + return findAuthByName(manager, name, true) } func (h *Handler) authIDForPath(path string) string { @@ -740,8 +795,8 @@ func (h *Handler) authIDForPath(path string) string { return "" } id := path - if h != nil && h.cfg != nil { - authDir := strings.TrimSpace(h.cfg.AuthDir) + if authDir, err := h.authDirSnapshot(); err == nil { + authDir = strings.TrimSpace(authDir) if authDir != "" { if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" { id = rel @@ -756,7 +811,8 @@ func (h *Handler) authIDForPath(path string) string { } func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error { - if h.authManager == nil { + manager, err := h.authManagerSnapshot() + if err != nil || manager == nil { return nil } if path == "" { @@ -805,23 +861,24 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data [] if hasLastRefresh { auth.LastRefreshedAt = lastRefresh } - if existing, ok := h.authManager.GetByID(authID); ok { + if existing, ok := manager.GetByID(authID); ok { auth.CreatedAt = existing.CreatedAt if !hasLastRefresh { auth.LastRefreshedAt = existing.LastRefreshedAt } auth.NextRefreshAfter = existing.NextRefreshAfter auth.Runtime = existing.Runtime - _, err := h.authManager.Update(ctx, auth) + _, err := manager.Update(ctx, auth) return err } - _, err := h.authManager.Register(ctx, auth) + _, err = manager.Register(ctx, auth) return err } // PatchAuthFileStatus toggles the disabled state of an auth file func (h *Handler) PatchAuthFileStatus(c *gin.Context) { - if h.authManager == nil { + manager, err := h.authManagerSnapshot() + if err != nil || manager == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) return } @@ -849,17 +906,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) { // Find auth by name or ID var targetAuth *coreauth.Auth - if auth, ok := h.authManager.GetByID(name); ok { - targetAuth = auth - } else { - auths := h.authManager.List() - for _, auth := range auths { - if auth.FileName == name { - targetAuth = auth - break - } - } - } + targetAuth = findAuthByName(manager, name, false) if targetAuth == nil { c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"}) @@ -877,7 +924,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) { } targetAuth.UpdatedAt = time.Now() - if _, err := h.authManager.Update(ctx, targetAuth); err != nil { + if _, err := manager.Update(ctx, targetAuth); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)}) return } @@ -887,7 +934,8 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) { // PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file. func (h *Handler) PatchAuthFileFields(c *gin.Context) { - if h.authManager == nil { + manager, err := h.authManagerSnapshot() + if err != nil || manager == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) return } @@ -914,17 +962,7 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) { // Find auth by name or ID var targetAuth *coreauth.Auth - if auth, ok := h.authManager.GetByID(name); ok { - targetAuth = auth - } else { - auths := h.authManager.List() - for _, auth := range auths { - if auth.FileName == name { - targetAuth = auth - break - } - } - } + targetAuth = findAuthByName(manager, name, false) if targetAuth == nil { c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"}) @@ -977,7 +1015,7 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) { targetAuth.UpdatedAt = time.Now() - if _, err := h.authManager.Update(ctx, targetAuth); err != nil { + if _, err := manager.Update(ctx, targetAuth); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)}) return } @@ -986,31 +1024,32 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) { } func (h *Handler) disableAuth(ctx context.Context, id string) { - if h == nil || h.authManager == nil { + manager, err := h.authManagerSnapshot() + if h == nil || err != nil || manager == nil { return } id = strings.TrimSpace(id) if id == "" { return } - if auth, ok := h.authManager.GetByID(id); ok { + if auth, ok := manager.GetByID(id); ok { auth.Disabled = true auth.Status = coreauth.StatusDisabled auth.StatusMessage = "removed via management API" auth.UpdatedAt = time.Now() - _, _ = h.authManager.Update(ctx, auth) + _, _ = manager.Update(ctx, auth) return } authID := h.authIDForPath(id) if authID == "" { return } - if auth, ok := h.authManager.GetByID(authID); ok { + if auth, ok := manager.GetByID(authID); ok { auth.Disabled = true auth.Status = coreauth.StatusDisabled auth.StatusMessage = "removed via management API" auth.UpdatedAt = time.Now() - _, _ = h.authManager.Update(ctx, auth) + _, _ = manager.Update(ctx, auth) } } @@ -1029,14 +1068,22 @@ func (h *Handler) tokenStoreWithBaseDir() coreauth.Store { if h == nil { return nil } - store := h.tokenStore + snapshot, err := h.runtimeSnapshot() + if err != nil { + return nil + } + store := snapshot.tokenStore if store == nil { store = sdkAuth.GetTokenStore() - h.tokenStore = store + h.stateMu.Lock() + if h.tokenStore == nil { + h.tokenStore = store + } + h.stateMu.Unlock() } - if h.cfg != nil { + if snapshot.cfg != nil { if dirSetter, ok := store.(interface{ SetBaseDir(string) }); ok { - dirSetter.SetBaseDir(h.cfg.AuthDir) + dirSetter.SetBaseDir(snapshot.cfg.AuthDir) } } return store @@ -1050,8 +1097,12 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s if store == nil { return "", fmt.Errorf("token store unavailable") } - if h.postAuthHook != nil { - if err := h.postAuthHook(ctx, record); err != nil { + snapshot, err := h.runtimeSnapshot() + if err != nil { + return "", err + } + if snapshot.postAuthHook != nil { + if err := snapshot.postAuthHook(ctx, record); err != nil { return "", fmt.Errorf("post-auth hook failed: %w", err) } } @@ -1061,6 +1112,11 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s func (h *Handler) RequestAnthropicToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } fmt.Println("Initializing Claude authentication...") @@ -1081,7 +1137,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } // Initialize Claude auth service - anthropicAuth := claude.NewClaudeAuth(h.cfg) + anthropicAuth := claude.NewClaudeAuth(cfg) // Generate authorization URL (then override redirect_uri to reuse server port) authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes) @@ -1091,7 +1147,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { return } - RegisterOAuthSession(state, "anthropic") + authDir := h.registerOAuthSession(state, "anthropic") isWebUI := isWebUIRequest(c) var forwarder *callbackForwarder @@ -1116,7 +1172,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } // Helper: wait for callback file - waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state)) + waitFile := filepath.Join(authDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state)) waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { deadline := time.Now().Add(timeout) for { @@ -1206,7 +1262,12 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) - proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } + proxyHTTPClient := util.SetProxy(&cfg.SDKConfig, &http.Client{}) ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient) // Optional project ID from query @@ -1227,7 +1288,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - RegisterOAuthSession(state, "gemini") + authDir := h.registerOAuthSession(state, "gemini") isWebUI := isWebUIRequest(c) var forwarder *callbackForwarder @@ -1252,7 +1313,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } // Wait for callback file written by server route - waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state)) + waitFile := filepath.Join(authDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state)) fmt.Println("Waiting for authentication callback...") deadline := time.Now().Add(5 * time.Minute) var authCode string @@ -1356,7 +1417,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings gemAuth := geminiAuth.NewGeminiAuth() - gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{ + gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, cfg, &geminiAuth.WebLoginOptions{ NoBrowser: true, }) if errGetClient != nil { @@ -1465,6 +1526,11 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { func (h *Handler) RequestCodexToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } fmt.Println("Initializing Codex authentication...") @@ -1485,7 +1551,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } // Initialize Codex auth service - openaiAuth := codex.NewCodexAuth(h.cfg) + openaiAuth := codex.NewCodexAuth(cfg) // Generate authorization URL authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes) @@ -1495,7 +1561,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { return } - RegisterOAuthSession(state, "codex") + authDir := h.registerOAuthSession(state, "codex") isWebUI := isWebUIRequest(c) var forwarder *callbackForwarder @@ -1520,7 +1586,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } // Wait for callback file - waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state)) + waitFile := filepath.Join(authDir, fmt.Sprintf(".oauth-codex-%s.oauth", state)) deadline := time.Now().Add(5 * time.Minute) var code string for { @@ -1611,10 +1677,15 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { func (h *Handler) RequestAntigravityToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } fmt.Println("Initializing Antigravity authentication...") - authSvc := antigravity.NewAntigravityAuth(h.cfg, nil) + authSvc := antigravity.NewAntigravityAuth(cfg, nil) state, errState := misc.GenerateRandomState() if errState != nil { @@ -1626,7 +1697,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort) authURL := authSvc.BuildAuthURL(state, redirectURI) - RegisterOAuthSession(state, "antigravity") + authDir := h.registerOAuthSession(state, "antigravity") isWebUI := isWebUIRequest(c) var forwarder *callbackForwarder @@ -1650,7 +1721,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder) } - waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) + waitFile := filepath.Join(authDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) deadline := time.Now().Add(5 * time.Minute) var authCode string for { @@ -1776,12 +1847,17 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { func (h *Handler) RequestQwenToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } fmt.Println("Initializing Qwen authentication...") state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) // Initialize Qwen auth service - qwenAuth := qwen.NewQwenAuth(h.cfg) + qwenAuth := qwen.NewQwenAuth(cfg) // Generate authorization URL deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx) @@ -1792,7 +1868,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { } authURL := deviceFlow.VerificationURIComplete - RegisterOAuthSession(state, "qwen") + h.registerOAuthSession(state, "qwen") go func() { fmt.Println("Waiting for authentication...") @@ -1832,12 +1908,17 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { func (h *Handler) RequestKimiToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } fmt.Println("Initializing Kimi authentication...") state := fmt.Sprintf("kmi-%d", time.Now().UnixNano()) // Initialize Kimi auth service - kimiAuth := kimi.NewKimiAuth(h.cfg) + kimiAuth := kimi.NewKimiAuth(cfg) // Generate authorization URL deviceFlow, errStartDeviceFlow := kimiAuth.StartDeviceFlow(ctx) @@ -1851,7 +1932,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) { authURL = deviceFlow.VerificationURI } - RegisterOAuthSession(state, "kimi") + h.registerOAuthSession(state, "kimi") go func() { fmt.Println("Waiting for authentication...") @@ -1909,14 +1990,19 @@ func (h *Handler) RequestKimiToken(c *gin.Context) { func (h *Handler) RequestIFlowToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "configuration unavailable"}) + return + } fmt.Println("Initializing iFlow authentication...") state := fmt.Sprintf("ifl-%d", time.Now().UnixNano()) - authSvc := iflowauth.NewIFlowAuth(h.cfg) + authSvc := iflowauth.NewIFlowAuth(cfg) authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort) - RegisterOAuthSession(state, "iflow") + authDir := h.registerOAuthSession(state, "iflow") isWebUI := isWebUIRequest(c) var forwarder *callbackForwarder @@ -1941,7 +2027,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } fmt.Println("Waiting for authentication...") - waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state)) + waitFile := filepath.Join(authDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state)) deadline := time.Now().Add(5 * time.Minute) var resultMap map[string]string for { @@ -2022,6 +2108,11 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { func (h *Handler) RequestIFlowCookieToken(c *gin.Context) { ctx := context.Background() + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "configuration unavailable"}) + return + } var payload struct { Cookie string `json:"cookie"` @@ -2046,7 +2137,7 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) { // Check for duplicate BXAuth before authentication bxAuth := iflowauth.ExtractBXAuth(cookieValue) - if existingFile, err := iflowauth.CheckDuplicateBXAuth(h.cfg.AuthDir, bxAuth); err != nil { + if existingFile, err := iflowauth.CheckDuplicateBXAuth(cfg.AuthDir, bxAuth); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to check duplicate"}) return } else if existingFile != "" { @@ -2055,7 +2146,7 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) { return } - authSvc := iflowauth.NewIFlowAuth(h.cfg) + authSvc := iflowauth.NewIFlowAuth(cfg) tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue) if errAuth != nil { c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": errAuth.Error()}) diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index f77e91e9ba..c900db9db4 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -24,11 +24,16 @@ const ( ) func (h *Handler) GetConfig(c *gin.Context) { - if h == nil || h.cfg == nil { - c.JSON(200, gin.H{}) + cfg, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "config snapshot failed", "message": err.Error()}) + return + } + if cfg == nil { + c.JSON(http.StatusOK, gin.H{}) return } - c.JSON(200, new(*h.cfg)) + c.JSON(http.StatusOK, cfg) } type releaseInfo struct { @@ -39,9 +44,14 @@ type releaseInfo struct { // GetLatestVersion returns the latest release version from GitHub without downloading assets. func (h *Handler) GetLatestVersion(c *gin.Context) { client := &http.Client{Timeout: 10 * time.Second} + cfg, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "config snapshot failed", "message": err.Error()}) + return + } proxyURL := "" - if h != nil && h.cfg != nil { - proxyURL = strings.TrimSpace(h.cfg.ProxyURL) + if cfg != nil { + proxyURL = strings.TrimSpace(cfg.ProxyURL) } if proxyURL != "" { sdkCfg := &sdkconfig.SDKConfig{ProxyURL: proxyURL} @@ -93,7 +103,7 @@ func (h *Handler) GetLatestVersion(c *gin.Context) { func WriteConfig(path string, data []byte) error { data = config.NormalizeCommentIndentation(data) - f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) if err != nil { return err } @@ -114,13 +124,18 @@ func (h *Handler) PutConfigYAML(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": "cannot read request body"}) return } - var cfg config.Config - if err = yaml.Unmarshal(body, &cfg); err != nil { + var parsed config.Config + if err = yaml.Unmarshal(body, &parsed); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": err.Error()}) return } - // Validate config using LoadConfigOptional with optional=false to enforce parsing - tmpDir := filepath.Dir(h.configFilePath) + snapshot, err := h.runtimeSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "config snapshot failed", "message": err.Error()}) + return + } + + tmpDir := filepath.Dir(snapshot.configFilePath) tmpFile, err := os.CreateTemp(tmpDir, "config-validate-*.yaml") if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()}) @@ -141,31 +156,52 @@ func (h *Handler) PutConfigYAML(c *gin.Context) { defer func() { _ = os.Remove(tempFile) }() - _, err = config.LoadConfigOptional(tempFile, false) + + validatedCfg, err := config.LoadConfigOptional(tempFile, false) if err != nil { c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "invalid_config", "message": err.Error()}) return } + if err := config.SaveConfigPreserveComments(tempFile, validatedCfg); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()}) + return + } + normalizedBody, err := os.ReadFile(tempFile) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "read_failed", "message": err.Error()}) + return + } + h.mu.Lock() defer h.mu.Unlock() - if WriteConfig(h.configFilePath, body) != nil { + if err := WriteConfig(snapshot.configFilePath, normalizedBody); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": "failed to write config"}) return } - // Reload into handler to keep memory in sync - newCfg, err := config.LoadConfig(h.configFilePath) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "reload_failed", "message": err.Error()}) + if _, err := h.reloadCommittedConfig(snapshot); err != nil { + message := err.Error() + switch { + case strings.HasPrefix(message, "failed to reload config: "): + c.JSON(http.StatusInternalServerError, gin.H{"error": "reload_failed", "message": strings.TrimPrefix(message, "failed to reload config: ")}) + case strings.HasPrefix(message, "failed to apply runtime config: "): + c.JSON(http.StatusInternalServerError, gin.H{"error": "runtime_apply_failed", "message": strings.TrimPrefix(message, "failed to apply runtime config: ")}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": message}) + } return } - h.cfg = newCfg c.JSON(http.StatusOK, gin.H{"ok": true, "changed": []string{"config"}}) } // GetConfigYAML returns the raw config.yaml file bytes without re-encoding. // It preserves comments and original formatting/styles. func (h *Handler) GetConfigYAML(c *gin.Context) { - data, err := os.ReadFile(h.configFilePath) + snapshot, err := h.runtimeSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "config snapshot failed", "message": err.Error()}) + return + } + data, err := os.ReadFile(snapshot.configFilePath) if err != nil { if os.IsNotExist(err) { c.JSON(http.StatusNotFound, gin.H{"error": "not_found", "message": "config file not found"}) @@ -177,106 +213,168 @@ func (h *Handler) GetConfigYAML(c *gin.Context) { c.Header("Content-Type", "application/yaml; charset=utf-8") c.Header("Cache-Control", "no-store") c.Header("X-Content-Type-Options", "nosniff") - // Write raw bytes as-is _, _ = c.Writer.Write(data) } // Debug -func (h *Handler) GetDebug(c *gin.Context) { c.JSON(200, gin.H{"debug": h.cfg.Debug}) } -func (h *Handler) PutDebug(c *gin.Context) { h.updateBoolField(c, func(v bool) { h.cfg.Debug = v }) } +func (h *Handler) GetDebug(c *gin.Context) { + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"debug": false}) + return + } + c.JSON(http.StatusOK, gin.H{"debug": cfg.Debug}) +} + +func (h *Handler) PutDebug(c *gin.Context) { + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.Debug = v }) +} // UsageStatisticsEnabled func (h *Handler) GetUsageStatisticsEnabled(c *gin.Context) { - c.JSON(200, gin.H{"usage-statistics-enabled": h.cfg.UsageStatisticsEnabled}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"usage-statistics-enabled": false}) + return + } + c.JSON(http.StatusOK, gin.H{"usage-statistics-enabled": cfg.UsageStatisticsEnabled}) } + func (h *Handler) PutUsageStatisticsEnabled(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.UsageStatisticsEnabled = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.UsageStatisticsEnabled = v }) } -// UsageStatisticsEnabled +// LoggingToFile func (h *Handler) GetLoggingToFile(c *gin.Context) { - c.JSON(200, gin.H{"logging-to-file": h.cfg.LoggingToFile}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"logging-to-file": false}) + return + } + c.JSON(http.StatusOK, gin.H{"logging-to-file": cfg.LoggingToFile}) } + func (h *Handler) PutLoggingToFile(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.LoggingToFile = v }) } // LogsMaxTotalSizeMB func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) { - c.JSON(200, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB}) -} -func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) { - var body struct { - Value *int `json:"value"` - } - if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"logs-max-total-size-mb": 0}) return } - value := *body.Value - if value < 0 { - value = 0 - } - h.cfg.LogsMaxTotalSizeMB = value - h.persist(c) + c.JSON(http.StatusOK, gin.H{"logs-max-total-size-mb": cfg.LogsMaxTotalSizeMB}) +} + +func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) { + h.updateIntField(c, func(cfg *config.Config, value int) error { + if value < 0 { + value = 0 + } + if value > config.MaxLogsMaxTotalSizeMB { + return fmt.Errorf("logs-max-total-size-mb exceeds allowed maximum") + } + cfg.LogsMaxTotalSizeMB = value + return nil + }) } // ErrorLogsMaxFiles func (h *Handler) GetErrorLogsMaxFiles(c *gin.Context) { - c.JSON(200, gin.H{"error-logs-max-files": h.cfg.ErrorLogsMaxFiles}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"error-logs-max-files": 0}) + return + } + c.JSON(http.StatusOK, gin.H{"error-logs-max-files": cfg.ErrorLogsMaxFiles}) } + func (h *Handler) PutErrorLogsMaxFiles(c *gin.Context) { - var body struct { - Value *int `json:"value"` - } - if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + h.updateIntField(c, func(cfg *config.Config, value int) error { + if value < 0 { + value = 10 + } + cfg.ErrorLogsMaxFiles = value + return nil + }) +} + +// Request log +func (h *Handler) GetRequestLog(c *gin.Context) { + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"request-log": false}) return } - value := *body.Value - if value < 0 { - value = 10 - } - h.cfg.ErrorLogsMaxFiles = value - h.persist(c) + c.JSON(http.StatusOK, gin.H{"request-log": cfg.RequestLog}) } -// Request log -func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) } func (h *Handler) PutRequestLog(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.RequestLog = v }) } // Websocket auth func (h *Handler) GetWebsocketAuth(c *gin.Context) { - c.JSON(200, gin.H{"ws-auth": h.cfg.WebsocketAuth}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"ws-auth": false}) + return + } + c.JSON(http.StatusOK, gin.H{"ws-auth": cfg.WebsocketAuth}) } + func (h *Handler) PutWebsocketAuth(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.WebsocketAuth = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.WebsocketAuth = v }) } // Request retry func (h *Handler) GetRequestRetry(c *gin.Context) { - c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"request-retry": 0}) + return + } + c.JSON(http.StatusOK, gin.H{"request-retry": cfg.RequestRetry}) } + func (h *Handler) PutRequestRetry(c *gin.Context) { - h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v }) + h.updateIntField(c, func(cfg *config.Config, v int) error { + cfg.RequestRetry = v + return nil + }) } // Max retry interval func (h *Handler) GetMaxRetryInterval(c *gin.Context) { - c.JSON(200, gin.H{"max-retry-interval": h.cfg.MaxRetryInterval}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"max-retry-interval": 0}) + return + } + c.JSON(http.StatusOK, gin.H{"max-retry-interval": cfg.MaxRetryInterval}) } + func (h *Handler) PutMaxRetryInterval(c *gin.Context) { - h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v }) + h.updateIntField(c, func(cfg *config.Config, v int) error { + cfg.MaxRetryInterval = v + return nil + }) } // ForceModelPrefix func (h *Handler) GetForceModelPrefix(c *gin.Context) { - c.JSON(200, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"force-model-prefix": false}) + return + } + c.JSON(http.StatusOK, gin.H{"force-model-prefix": cfg.ForceModelPrefix}) } + func (h *Handler) PutForceModelPrefix(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.ForceModelPrefix = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.ForceModelPrefix = v }) } func normalizeRoutingStrategy(strategy string) (string, bool) { @@ -293,13 +391,19 @@ func normalizeRoutingStrategy(strategy string) (string, bool) { // RoutingStrategy func (h *Handler) GetRoutingStrategy(c *gin.Context) { - strategy, ok := normalizeRoutingStrategy(h.cfg.Routing.Strategy) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"strategy": ""}) + return + } + strategy, ok := normalizeRoutingStrategy(cfg.Routing.Strategy) if !ok { - c.JSON(200, gin.H{"strategy": strings.TrimSpace(h.cfg.Routing.Strategy)}) + c.JSON(http.StatusOK, gin.H{"strategy": strings.TrimSpace(cfg.Routing.Strategy)}) return } - c.JSON(200, gin.H{"strategy": strategy}) + c.JSON(http.StatusOK, gin.H{"strategy": strategy}) } + func (h *Handler) PutRoutingStrategy(c *gin.Context) { var body struct { Value *string `json:"value"` @@ -313,16 +417,29 @@ func (h *Handler) PutRoutingStrategy(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"}) return } - h.cfg.Routing.Strategy = normalized - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.Routing.Strategy = normalized + return nil + }) } // Proxy URL -func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) } +func (h *Handler) GetProxyURL(c *gin.Context) { + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(http.StatusOK, gin.H{"proxy-url": ""}) + return + } + c.JSON(http.StatusOK, gin.H{"proxy-url": cfg.ProxyURL}) +} + func (h *Handler) PutProxyURL(c *gin.Context) { - h.updateStringField(c, func(v string) { h.cfg.ProxyURL = v }) + h.updateStringField(c, func(cfg *config.Config, v string) { cfg.ProxyURL = v }) } + func (h *Handler) DeleteProxyURL(c *gin.Context) { - h.cfg.ProxyURL = "" - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.ProxyURL = "" + return nil + }) } diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 083d4e31ef..3232d4aa6e 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -9,32 +9,82 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" ) -// Generic helpers for list[string] -func (h *Handler) putStringList(c *gin.Context, set func([]string), after func()) { +func (h *Handler) configSnapshotOrEmpty() *config.Config { + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + return &config.Config{} + } + return cfg +} + +func decodeJSONItems[T any](c *gin.Context) ([]T, bool) { data, err := c.GetRawData() if err != nil { c.JSON(400, gin.H{"error": "failed to read body"}) - return + return nil, false } - var arr []string - if err = json.Unmarshal(data, &arr); err != nil { - var obj struct { - Items []string `json:"items"` - } - if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 { - c.JSON(400, gin.H{"error": "invalid body"}) - return + + var items []T + if err = json.Unmarshal(data, &items); err == nil { + return items, true + } + + var wrapped struct { + Items []T `json:"items"` + } + if err = json.Unmarshal(data, &wrapped); err != nil || len(wrapped.Items) == 0 { + c.JSON(400, gin.H{"error": "invalid body"}) + return nil, false + } + return wrapped.Items, true +} + +func findIndexByIndexOrMatch[T any](items []T, index *int, match *string, matches func(T, string) bool) int { + if index != nil && *index >= 0 && *index < len(items) { + return *index + } + if match == nil { + return -1 + } + + needle := strings.TrimSpace(*match) + if needle == "" { + return -1 + } + for i := range items { + if matches(items[i], needle) { + return i } - arr = obj.Items } - set(arr) - if after != nil { - after() + return -1 +} + +func (h *Handler) ampCodeSnapshot() config.AmpCode { + return h.configSnapshotOrEmpty().AmpCode +} + +func (h *Handler) mutateAmpCode(c *gin.Context, mutate func(*config.AmpCode) error) { + h.applyConfigMutation(c, func(cfg *config.Config) error { + return mutate(&cfg.AmpCode) + }) +} + +// Generic helpers for list[string] +func (h *Handler) putStringList(c *gin.Context, set func(*config.Config, []string), after func(*config.Config)) { + arr, ok := decodeJSONItems[string](c) + if !ok { + return } - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + set(cfg, arr) + if after != nil { + after(cfg) + } + return nil + }) } -func (h *Handler) patchStringList(c *gin.Context, target *[]string, after func()) { +func (h *Handler) patchStringList(c *gin.Context, target func(*config.Config) *[]string, after func(*config.Config)) { var body struct { Old *string `json:"old"` New *string `json:"new"` @@ -45,103 +95,106 @@ func (h *Handler) patchStringList(c *gin.Context, target *[]string, after func() c.JSON(400, gin.H{"error": "invalid body"}) return } - if body.Index != nil && body.Value != nil && *body.Index >= 0 && *body.Index < len(*target) { - (*target)[*body.Index] = *body.Value - if after != nil { - after() - } - h.persist(c) - return - } - if body.Old != nil && body.New != nil { - for i := range *target { - if (*target)[i] == *body.Old { - (*target)[i] = *body.New - if after != nil { - after() + h.applyConfigMutation(c, func(cfg *config.Config) error { + items := target(cfg) + if body.Index != nil && body.Value != nil && *body.Index >= 0 && *body.Index < len(*items) { + (*items)[*body.Index] = *body.Value + if after != nil { + after(cfg) + } + return nil + } + if body.Old != nil && body.New != nil { + for i := range *items { + if (*items)[i] == *body.Old { + (*items)[i] = *body.New + if after != nil { + after(cfg) + } + return nil } - h.persist(c) - return } + *items = append(*items, *body.New) + if after != nil { + after(cfg) + } + return nil } - *target = append(*target, *body.New) - if after != nil { - after() - } - h.persist(c) - return - } - c.JSON(400, gin.H{"error": "missing fields"}) + return fmt.Errorf("missing fields") + }) } -func (h *Handler) deleteFromStringList(c *gin.Context, target *[]string, after func()) { +func (h *Handler) deleteFromStringList(c *gin.Context, target func(*config.Config) *[]string, after func(*config.Config)) { if idxStr := c.Query("index"); idxStr != "" { var idx int _, err := fmt.Sscanf(idxStr, "%d", &idx) - if err == nil && idx >= 0 && idx < len(*target) { - *target = append((*target)[:idx], (*target)[idx+1:]...) - if after != nil { - after() - } - h.persist(c) + if err == nil { + h.applyConfigMutation(c, func(cfg *config.Config) error { + items := target(cfg) + if idx < 0 || idx >= len(*items) { + return fmt.Errorf("missing index or value") + } + *items = append((*items)[:idx], (*items)[idx+1:]...) + if after != nil { + after(cfg) + } + return nil + }) return } } if val := strings.TrimSpace(c.Query("value")); val != "" { - out := make([]string, 0, len(*target)) - for _, v := range *target { - if strings.TrimSpace(v) != val { - out = append(out, v) + h.applyConfigMutation(c, func(cfg *config.Config) error { + items := target(cfg) + out := make([]string, 0, len(*items)) + for _, v := range *items { + if strings.TrimSpace(v) != val { + out = append(out, v) + } } - } - *target = out - if after != nil { - after() - } - h.persist(c) + *items = out + if after != nil { + after(cfg) + } + return nil + }) return } c.JSON(400, gin.H{"error": "missing index or value"}) } // api-keys -func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.cfg.APIKeys}) } +func (h *Handler) GetAPIKeys(c *gin.Context) { + cfg := h.configSnapshotOrEmpty() + c.JSON(200, gin.H{"api-keys": cfg.APIKeys}) +} func (h *Handler) PutAPIKeys(c *gin.Context) { - h.putStringList(c, func(v []string) { - h.cfg.APIKeys = append([]string(nil), v...) + h.putStringList(c, func(cfg *config.Config, v []string) { + cfg.APIKeys = append([]string(nil), v...) }, nil) } func (h *Handler) PatchAPIKeys(c *gin.Context) { - h.patchStringList(c, &h.cfg.APIKeys, func() {}) + h.patchStringList(c, func(cfg *config.Config) *[]string { return &cfg.APIKeys }, nil) } func (h *Handler) DeleteAPIKeys(c *gin.Context) { - h.deleteFromStringList(c, &h.cfg.APIKeys, func() {}) + h.deleteFromStringList(c, func(cfg *config.Config) *[]string { return &cfg.APIKeys }, nil) } // gemini-api-key: []GeminiKey func (h *Handler) GetGeminiKeys(c *gin.Context) { - c.JSON(200, gin.H{"gemini-api-key": h.cfg.GeminiKey}) + cfg := h.configSnapshotOrEmpty() + c.JSON(200, gin.H{"gemini-api-key": cfg.GeminiKey}) } func (h *Handler) PutGeminiKeys(c *gin.Context) { - data, err := c.GetRawData() - if err != nil { - c.JSON(400, gin.H{"error": "failed to read body"}) + arr, ok := decodeJSONItems[config.GeminiKey](c) + if !ok { return } - var arr []config.GeminiKey - if err = json.Unmarshal(data, &arr); err != nil { - var obj struct { - Items []config.GeminiKey `json:"items"` - } - if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 { - c.JSON(400, gin.H{"error": "invalid body"}) - return - } - arr = obj.Items - } - h.cfg.GeminiKey = append([]config.GeminiKey(nil), arr...) - h.cfg.SanitizeGeminiKeys() - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.GeminiKey = append([]config.GeminiKey(nil), arr...) + cfg.SanitizeGeminiKeys() + return nil + }) } func (h *Handler) PatchGeminiKey(c *gin.Context) { type geminiKeyPatch struct { @@ -161,80 +214,73 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } - targetIndex := -1 - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { - targetIndex = *body.Index - } - if targetIndex == -1 && body.Match != nil { - match := strings.TrimSpace(*body.Match) - if match != "" { - for i := range h.cfg.GeminiKey { - if h.cfg.GeminiKey[i].APIKey == match { - targetIndex = i - break - } + h.applyConfigMutation(c, func(cfg *config.Config) error { + targetIndex := findIndexByIndexOrMatch(cfg.GeminiKey, body.Index, body.Match, func(item config.GeminiKey, match string) bool { + return item.APIKey == match + }) + if targetIndex == -1 { + return fmt.Errorf("item not found") + } + entry := cfg.GeminiKey[targetIndex] + if body.Value.APIKey != nil { + trimmed := strings.TrimSpace(*body.Value.APIKey) + if trimmed == "" { + cfg.GeminiKey = append(cfg.GeminiKey[:targetIndex], cfg.GeminiKey[targetIndex+1:]...) + cfg.SanitizeGeminiKeys() + return nil } + entry.APIKey = trimmed } - } - if targetIndex == -1 { - c.JSON(404, gin.H{"error": "item not found"}) - return - } - - entry := h.cfg.GeminiKey[targetIndex] - if body.Value.APIKey != nil { - trimmed := strings.TrimSpace(*body.Value.APIKey) - if trimmed == "" { - h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...) - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) } - entry.APIKey = trimmed - } - if body.Value.Prefix != nil { - entry.Prefix = strings.TrimSpace(*body.Value.Prefix) - } - if body.Value.BaseURL != nil { - entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) - } - if body.Value.ProxyURL != nil { - entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) - } - if body.Value.Headers != nil { - entry.Headers = config.NormalizeHeaders(*body.Value.Headers) - } - if body.Value.ExcludedModels != nil { - entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) - } - h.cfg.GeminiKey[targetIndex] = entry - h.cfg.SanitizeGeminiKeys() - h.persist(c) + if body.Value.BaseURL != nil { + entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + cfg.GeminiKey[targetIndex] = entry + cfg.SanitizeGeminiKeys() + return nil + }) } func (h *Handler) DeleteGeminiKey(c *gin.Context) { if val := strings.TrimSpace(c.Query("api-key")); val != "" { - out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) - for _, v := range h.cfg.GeminiKey { - if v.APIKey != val { - out = append(out, v) + h.applyConfigMutation(c, func(cfg *config.Config) error { + out := make([]config.GeminiKey, 0, len(cfg.GeminiKey)) + for _, v := range cfg.GeminiKey { + if v.APIKey != val { + out = append(out, v) + } } - } - if len(out) != len(h.cfg.GeminiKey) { - h.cfg.GeminiKey = out - h.cfg.SanitizeGeminiKeys() - h.persist(c) - } else { - c.JSON(404, gin.H{"error": "item not found"}) - } + if len(out) == len(cfg.GeminiKey) { + return fmt.Errorf("item not found") + } + cfg.GeminiKey = out + cfg.SanitizeGeminiKeys() + return nil + }) return } if idxStr := c.Query("index"); idxStr != "" { var idx int - if _, err := fmt.Sscanf(idxStr, "%d", &idx); err == nil && idx >= 0 && idx < len(h.cfg.GeminiKey) { - h.cfg.GeminiKey = append(h.cfg.GeminiKey[:idx], h.cfg.GeminiKey[idx+1:]...) - h.cfg.SanitizeGeminiKeys() - h.persist(c) + if _, err := fmt.Sscanf(idxStr, "%d", &idx); err == nil { + h.applyConfigMutation(c, func(cfg *config.Config) error { + if idx < 0 || idx >= len(cfg.GeminiKey) { + return fmt.Errorf("missing api-key or index") + } + cfg.GeminiKey = append(cfg.GeminiKey[:idx], cfg.GeminiKey[idx+1:]...) + cfg.SanitizeGeminiKeys() + return nil + }) return } } @@ -243,31 +289,22 @@ func (h *Handler) DeleteGeminiKey(c *gin.Context) { // claude-api-key: []ClaudeKey func (h *Handler) GetClaudeKeys(c *gin.Context) { - c.JSON(200, gin.H{"claude-api-key": h.cfg.ClaudeKey}) + cfg := h.configSnapshotOrEmpty() + c.JSON(200, gin.H{"claude-api-key": cfg.ClaudeKey}) } func (h *Handler) PutClaudeKeys(c *gin.Context) { - data, err := c.GetRawData() - if err != nil { - c.JSON(400, gin.H{"error": "failed to read body"}) + arr, ok := decodeJSONItems[config.ClaudeKey](c) + if !ok { return } - var arr []config.ClaudeKey - if err = json.Unmarshal(data, &arr); err != nil { - var obj struct { - Items []config.ClaudeKey `json:"items"` - } - if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 { - c.JSON(400, gin.H{"error": "invalid body"}) - return - } - arr = obj.Items - } for i := range arr { normalizeClaudeKey(&arr[i]) } - h.cfg.ClaudeKey = arr - h.cfg.SanitizeClaudeKeys() - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.ClaudeKey = arr + cfg.SanitizeClaudeKeys() + return nil + }) } func (h *Handler) PatchClaudeKey(c *gin.Context) { type claudeKeyPatch struct { @@ -288,72 +325,69 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } - targetIndex := -1 - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) { - targetIndex = *body.Index - } - if targetIndex == -1 && body.Match != nil { - match := strings.TrimSpace(*body.Match) - for i := range h.cfg.ClaudeKey { - if h.cfg.ClaudeKey[i].APIKey == match { - targetIndex = i - break - } + h.applyConfigMutation(c, func(cfg *config.Config) error { + targetIndex := findIndexByIndexOrMatch(cfg.ClaudeKey, body.Index, body.Match, func(item config.ClaudeKey, match string) bool { + return item.APIKey == match + }) + if targetIndex == -1 { + return fmt.Errorf("item not found") } - } - if targetIndex == -1 { - c.JSON(404, gin.H{"error": "item not found"}) - return - } - - entry := h.cfg.ClaudeKey[targetIndex] - if body.Value.APIKey != nil { - entry.APIKey = strings.TrimSpace(*body.Value.APIKey) - } - if body.Value.Prefix != nil { - entry.Prefix = strings.TrimSpace(*body.Value.Prefix) - } - if body.Value.BaseURL != nil { - entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) - } - if body.Value.ProxyURL != nil { - entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) - } - if body.Value.Models != nil { - entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...) - } - if body.Value.Headers != nil { - entry.Headers = config.NormalizeHeaders(*body.Value.Headers) - } - if body.Value.ExcludedModels != nil { - entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) - } - normalizeClaudeKey(&entry) - h.cfg.ClaudeKey[targetIndex] = entry - h.cfg.SanitizeClaudeKeys() - h.persist(c) + entry := cfg.ClaudeKey[targetIndex] + if body.Value.APIKey != nil { + entry.APIKey = strings.TrimSpace(*body.Value.APIKey) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Models != nil { + entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + normalizeClaudeKey(&entry) + cfg.ClaudeKey[targetIndex] = entry + cfg.SanitizeClaudeKeys() + return nil + }) } func (h *Handler) DeleteClaudeKey(c *gin.Context) { if val := c.Query("api-key"); val != "" { - out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) - for _, v := range h.cfg.ClaudeKey { - if v.APIKey != val { - out = append(out, v) + h.applyConfigMutation(c, func(cfg *config.Config) error { + out := make([]config.ClaudeKey, 0, len(cfg.ClaudeKey)) + for _, v := range cfg.ClaudeKey { + if v.APIKey != val { + out = append(out, v) + } } - } - h.cfg.ClaudeKey = out - h.cfg.SanitizeClaudeKeys() - h.persist(c) + cfg.ClaudeKey = out + cfg.SanitizeClaudeKeys() + return nil + }) return } if idxStr := c.Query("index"); idxStr != "" { var idx int _, err := fmt.Sscanf(idxStr, "%d", &idx) - if err == nil && idx >= 0 && idx < len(h.cfg.ClaudeKey) { - h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:idx], h.cfg.ClaudeKey[idx+1:]...) - h.cfg.SanitizeClaudeKeys() - h.persist(c) + if err == nil { + h.applyConfigMutation(c, func(cfg *config.Config) error { + if idx < 0 || idx >= len(cfg.ClaudeKey) { + return fmt.Errorf("missing api-key or index") + } + cfg.ClaudeKey = append(cfg.ClaudeKey[:idx], cfg.ClaudeKey[idx+1:]...) + cfg.SanitizeClaudeKeys() + return nil + }) return } } @@ -362,25 +396,14 @@ func (h *Handler) DeleteClaudeKey(c *gin.Context) { // openai-compatibility: []OpenAICompatibility func (h *Handler) GetOpenAICompat(c *gin.Context) { - c.JSON(200, gin.H{"openai-compatibility": normalizedOpenAICompatibilityEntries(h.cfg.OpenAICompatibility)}) + cfg := h.configSnapshotOrEmpty() + c.JSON(200, gin.H{"openai-compatibility": normalizedOpenAICompatibilityEntries(cfg.OpenAICompatibility)}) } func (h *Handler) PutOpenAICompat(c *gin.Context) { - data, err := c.GetRawData() - if err != nil { - c.JSON(400, gin.H{"error": "failed to read body"}) + arr, ok := decodeJSONItems[config.OpenAICompatibility](c) + if !ok { return } - var arr []config.OpenAICompatibility - if err = json.Unmarshal(data, &arr); err != nil { - var obj struct { - Items []config.OpenAICompatibility `json:"items"` - } - if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 { - c.JSON(400, gin.H{"error": "invalid body"}) - return - } - arr = obj.Items - } filtered := make([]config.OpenAICompatibility, 0, len(arr)) for i := range arr { normalizeOpenAICompatibilityEntry(&arr[i]) @@ -388,9 +411,11 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) { filtered = append(filtered, arr[i]) } } - h.cfg.OpenAICompatibility = filtered - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.OpenAICompatibility = filtered + cfg.SanitizeOpenAICompatibility() + return nil + }) } func (h *Handler) PatchOpenAICompat(c *gin.Context) { type openAICompatPatch struct { @@ -410,76 +435,72 @@ func (h *Handler) PatchOpenAICompat(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } - targetIndex := -1 - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { - targetIndex = *body.Index - } - if targetIndex == -1 && body.Name != nil { - match := strings.TrimSpace(*body.Name) - for i := range h.cfg.OpenAICompatibility { - if h.cfg.OpenAICompatibility[i].Name == match { - targetIndex = i - break + h.applyConfigMutation(c, func(cfg *config.Config) error { + targetIndex := findIndexByIndexOrMatch(cfg.OpenAICompatibility, body.Index, body.Name, func(item config.OpenAICompatibility, match string) bool { + return item.Name == match + }) + if targetIndex == -1 { + return fmt.Errorf("item not found") + } + entry := cfg.OpenAICompatibility[targetIndex] + if body.Value.Name != nil { + entry.Name = strings.TrimSpace(*body.Value.Name) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + cfg.OpenAICompatibility = append(cfg.OpenAICompatibility[:targetIndex], cfg.OpenAICompatibility[targetIndex+1:]...) + cfg.SanitizeOpenAICompatibility() + return nil } + entry.BaseURL = trimmed } - } - if targetIndex == -1 { - c.JSON(404, gin.H{"error": "item not found"}) - return - } - - entry := h.cfg.OpenAICompatibility[targetIndex] - if body.Value.Name != nil { - entry.Name = strings.TrimSpace(*body.Value.Name) - } - if body.Value.Prefix != nil { - entry.Prefix = strings.TrimSpace(*body.Value.Prefix) - } - if body.Value.BaseURL != nil { - trimmed := strings.TrimSpace(*body.Value.BaseURL) - if trimmed == "" { - h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...) - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return + if body.Value.APIKeyEntries != nil { + entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...) } - entry.BaseURL = trimmed - } - if body.Value.APIKeyEntries != nil { - entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...) - } - if body.Value.Models != nil { - entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...) - } - if body.Value.Headers != nil { - entry.Headers = config.NormalizeHeaders(*body.Value.Headers) - } - normalizeOpenAICompatibilityEntry(&entry) - h.cfg.OpenAICompatibility[targetIndex] = entry - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + if body.Value.Models != nil { + entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + normalizeOpenAICompatibilityEntry(&entry) + cfg.OpenAICompatibility[targetIndex] = entry + cfg.SanitizeOpenAICompatibility() + return nil + }) } func (h *Handler) DeleteOpenAICompat(c *gin.Context) { if name := c.Query("name"); name != "" { - out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) - for _, v := range h.cfg.OpenAICompatibility { - if v.Name != name { - out = append(out, v) + h.applyConfigMutation(c, func(cfg *config.Config) error { + out := make([]config.OpenAICompatibility, 0, len(cfg.OpenAICompatibility)) + for _, v := range cfg.OpenAICompatibility { + if v.Name != name { + out = append(out, v) + } } - } - h.cfg.OpenAICompatibility = out - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + cfg.OpenAICompatibility = out + cfg.SanitizeOpenAICompatibility() + return nil + }) return } if idxStr := c.Query("index"); idxStr != "" { var idx int _, err := fmt.Sscanf(idxStr, "%d", &idx) - if err == nil && idx >= 0 && idx < len(h.cfg.OpenAICompatibility) { - h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:idx], h.cfg.OpenAICompatibility[idx+1:]...) - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + if err == nil { + h.applyConfigMutation(c, func(cfg *config.Config) error { + if idx < 0 || idx >= len(cfg.OpenAICompatibility) { + return fmt.Errorf("missing name or index") + } + cfg.OpenAICompatibility = append(cfg.OpenAICompatibility[:idx], cfg.OpenAICompatibility[idx+1:]...) + cfg.SanitizeOpenAICompatibility() + return nil + }) return } } @@ -488,25 +509,14 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) { // vertex-api-key: []VertexCompatKey func (h *Handler) GetVertexCompatKeys(c *gin.Context) { - c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey}) + cfg := h.configSnapshotOrEmpty() + c.JSON(200, gin.H{"vertex-api-key": cfg.VertexCompatAPIKey}) } func (h *Handler) PutVertexCompatKeys(c *gin.Context) { - data, err := c.GetRawData() - if err != nil { - c.JSON(400, gin.H{"error": "failed to read body"}) + arr, ok := decodeJSONItems[config.VertexCompatKey](c) + if !ok { return } - var arr []config.VertexCompatKey - if err = json.Unmarshal(data, &arr); err != nil { - var obj struct { - Items []config.VertexCompatKey `json:"items"` - } - if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 { - c.JSON(400, gin.H{"error": "invalid body"}) - return - } - arr = obj.Items - } for i := range arr { normalizeVertexCompatKey(&arr[i]) if arr[i].APIKey == "" { @@ -514,9 +524,11 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) { return } } - h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...) - h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...) + cfg.SanitizeVertexCompatKeys() + return nil + }) } func (h *Handler) PatchVertexCompatKey(c *gin.Context) { type vertexCompatPatch struct { @@ -537,88 +549,81 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } - targetIndex := -1 - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) { - targetIndex = *body.Index - } - if targetIndex == -1 && body.Match != nil { - match := strings.TrimSpace(*body.Match) - if match != "" { - for i := range h.cfg.VertexCompatAPIKey { - if h.cfg.VertexCompatAPIKey[i].APIKey == match { - targetIndex = i - break - } + h.applyConfigMutation(c, func(cfg *config.Config) error { + targetIndex := findIndexByIndexOrMatch(cfg.VertexCompatAPIKey, body.Index, body.Match, func(item config.VertexCompatKey, match string) bool { + return item.APIKey == match + }) + if targetIndex == -1 { + return fmt.Errorf("item not found") + } + entry := cfg.VertexCompatAPIKey[targetIndex] + if body.Value.APIKey != nil { + trimmed := strings.TrimSpace(*body.Value.APIKey) + if trimmed == "" { + cfg.VertexCompatAPIKey = append(cfg.VertexCompatAPIKey[:targetIndex], cfg.VertexCompatAPIKey[targetIndex+1:]...) + cfg.SanitizeVertexCompatKeys() + return nil } + entry.APIKey = trimmed } - } - if targetIndex == -1 { - c.JSON(404, gin.H{"error": "item not found"}) - return - } - - entry := h.cfg.VertexCompatAPIKey[targetIndex] - if body.Value.APIKey != nil { - trimmed := strings.TrimSpace(*body.Value.APIKey) - if trimmed == "" { - h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...) - h.cfg.SanitizeVertexCompatKeys() - h.persist(c) - return + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) } - entry.APIKey = trimmed - } - if body.Value.Prefix != nil { - entry.Prefix = strings.TrimSpace(*body.Value.Prefix) - } - if body.Value.BaseURL != nil { - trimmed := strings.TrimSpace(*body.Value.BaseURL) - if trimmed == "" { - h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...) - h.cfg.SanitizeVertexCompatKeys() - h.persist(c) - return + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + cfg.VertexCompatAPIKey = append(cfg.VertexCompatAPIKey[:targetIndex], cfg.VertexCompatAPIKey[targetIndex+1:]...) + cfg.SanitizeVertexCompatKeys() + return nil + } + entry.BaseURL = trimmed } - entry.BaseURL = trimmed - } - if body.Value.ProxyURL != nil { - entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) - } - if body.Value.Headers != nil { - entry.Headers = config.NormalizeHeaders(*body.Value.Headers) - } - if body.Value.Models != nil { - entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...) - } - if body.Value.ExcludedModels != nil { - entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) - } - normalizeVertexCompatKey(&entry) - h.cfg.VertexCompatAPIKey[targetIndex] = entry - h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.Models != nil { + entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + normalizeVertexCompatKey(&entry) + cfg.VertexCompatAPIKey[targetIndex] = entry + cfg.SanitizeVertexCompatKeys() + return nil + }) } func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { if val := strings.TrimSpace(c.Query("api-key")); val != "" { - out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey)) - for _, v := range h.cfg.VertexCompatAPIKey { - if v.APIKey != val { - out = append(out, v) + h.applyConfigMutation(c, func(cfg *config.Config) error { + out := make([]config.VertexCompatKey, 0, len(cfg.VertexCompatAPIKey)) + for _, v := range cfg.VertexCompatAPIKey { + if v.APIKey != val { + out = append(out, v) + } } - } - h.cfg.VertexCompatAPIKey = out - h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + cfg.VertexCompatAPIKey = out + cfg.SanitizeVertexCompatKeys() + return nil + }) return } if idxStr := c.Query("index"); idxStr != "" { var idx int _, errScan := fmt.Sscanf(idxStr, "%d", &idx) - if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) { - h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...) - h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + if errScan == nil { + h.applyConfigMutation(c, func(cfg *config.Config) error { + if idx < 0 || idx >= len(cfg.VertexCompatAPIKey) { + return fmt.Errorf("missing api-key or index") + } + cfg.VertexCompatAPIKey = append(cfg.VertexCompatAPIKey[:idx], cfg.VertexCompatAPIKey[idx+1:]...) + cfg.SanitizeVertexCompatKeys() + return nil + }) return } } @@ -627,7 +632,8 @@ func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { // oauth-excluded-models: map[string][]string func (h *Handler) GetOAuthExcludedModels(c *gin.Context) { - c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)}) + cfg := h.configSnapshotOrEmpty() + c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)}) } func (h *Handler) PutOAuthExcludedModels(c *gin.Context) { @@ -647,8 +653,10 @@ func (h *Handler) PutOAuthExcludedModels(c *gin.Context) { } entries = wrapper.Items } - h.cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(entries) - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(entries) + return nil + }) } func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) { @@ -666,27 +674,26 @@ func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) { return } normalized := config.NormalizeExcludedModels(body.Models) - if len(normalized) == 0 { - if h.cfg.OAuthExcludedModels == nil { - c.JSON(404, gin.H{"error": "provider not found"}) - return - } - if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok { - c.JSON(404, gin.H{"error": "provider not found"}) - return + h.applyConfigMutation(c, func(cfg *config.Config) error { + if len(normalized) == 0 { + if cfg.OAuthExcludedModels == nil { + return fmt.Errorf("provider not found") + } + if _, ok := cfg.OAuthExcludedModels[provider]; !ok { + return fmt.Errorf("provider not found") + } + delete(cfg.OAuthExcludedModels, provider) + if len(cfg.OAuthExcludedModels) == 0 { + cfg.OAuthExcludedModels = nil + } + return nil } - delete(h.cfg.OAuthExcludedModels, provider) - if len(h.cfg.OAuthExcludedModels) == 0 { - h.cfg.OAuthExcludedModels = nil + if cfg.OAuthExcludedModels == nil { + cfg.OAuthExcludedModels = make(map[string][]string) } - h.persist(c) - return - } - if h.cfg.OAuthExcludedModels == nil { - h.cfg.OAuthExcludedModels = make(map[string][]string) - } - h.cfg.OAuthExcludedModels[provider] = normalized - h.persist(c) + cfg.OAuthExcludedModels[provider] = normalized + return nil + }) } func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) { @@ -695,24 +702,25 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) { c.JSON(400, gin.H{"error": "missing provider"}) return } - if h.cfg.OAuthExcludedModels == nil { - c.JSON(404, gin.H{"error": "provider not found"}) - return - } - if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok { - c.JSON(404, gin.H{"error": "provider not found"}) - return - } - delete(h.cfg.OAuthExcludedModels, provider) - if len(h.cfg.OAuthExcludedModels) == 0 { - h.cfg.OAuthExcludedModels = nil - } - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + if cfg.OAuthExcludedModels == nil { + return fmt.Errorf("provider not found") + } + if _, ok := cfg.OAuthExcludedModels[provider]; !ok { + return fmt.Errorf("provider not found") + } + delete(cfg.OAuthExcludedModels, provider) + if len(cfg.OAuthExcludedModels) == 0 { + cfg.OAuthExcludedModels = nil + } + return nil + }) } // oauth-model-alias: map[string][]OAuthModelAlias func (h *Handler) GetOAuthModelAlias(c *gin.Context) { - c.JSON(200, gin.H{"oauth-model-alias": sanitizedOAuthModelAlias(h.cfg.OAuthModelAlias)}) + cfg := h.configSnapshotOrEmpty() + c.JSON(200, gin.H{"oauth-model-alias": sanitizedOAuthModelAlias(cfg.OAuthModelAlias)}) } func (h *Handler) PutOAuthModelAlias(c *gin.Context) { @@ -732,8 +740,10 @@ func (h *Handler) PutOAuthModelAlias(c *gin.Context) { } entries = wrapper.Items } - h.cfg.OAuthModelAlias = sanitizedOAuthModelAlias(entries) - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.OAuthModelAlias = sanitizedOAuthModelAlias(entries) + return nil + }) } func (h *Handler) PatchOAuthModelAlias(c *gin.Context) { @@ -760,27 +770,26 @@ func (h *Handler) PatchOAuthModelAlias(c *gin.Context) { normalizedMap := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases}) normalized := normalizedMap[channel] - if len(normalized) == 0 { - if h.cfg.OAuthModelAlias == nil { - c.JSON(404, gin.H{"error": "channel not found"}) - return - } - if _, ok := h.cfg.OAuthModelAlias[channel]; !ok { - c.JSON(404, gin.H{"error": "channel not found"}) - return + h.applyConfigMutation(c, func(cfg *config.Config) error { + if len(normalized) == 0 { + if cfg.OAuthModelAlias == nil { + return fmt.Errorf("channel not found") + } + if _, ok := cfg.OAuthModelAlias[channel]; !ok { + return fmt.Errorf("channel not found") + } + delete(cfg.OAuthModelAlias, channel) + if len(cfg.OAuthModelAlias) == 0 { + cfg.OAuthModelAlias = nil + } + return nil } - delete(h.cfg.OAuthModelAlias, channel) - if len(h.cfg.OAuthModelAlias) == 0 { - h.cfg.OAuthModelAlias = nil + if cfg.OAuthModelAlias == nil { + cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias) } - h.persist(c) - return - } - if h.cfg.OAuthModelAlias == nil { - h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias) - } - h.cfg.OAuthModelAlias[channel] = normalized - h.persist(c) + cfg.OAuthModelAlias[channel] = normalized + return nil + }) } func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) { @@ -792,42 +801,31 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) { c.JSON(400, gin.H{"error": "missing channel"}) return } - if h.cfg.OAuthModelAlias == nil { - c.JSON(404, gin.H{"error": "channel not found"}) - return - } - if _, ok := h.cfg.OAuthModelAlias[channel]; !ok { - c.JSON(404, gin.H{"error": "channel not found"}) - return - } - delete(h.cfg.OAuthModelAlias, channel) - if len(h.cfg.OAuthModelAlias) == 0 { - h.cfg.OAuthModelAlias = nil - } - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + if cfg.OAuthModelAlias == nil { + return fmt.Errorf("channel not found") + } + if _, ok := cfg.OAuthModelAlias[channel]; !ok { + return fmt.Errorf("channel not found") + } + delete(cfg.OAuthModelAlias, channel) + if len(cfg.OAuthModelAlias) == 0 { + cfg.OAuthModelAlias = nil + } + return nil + }) } // codex-api-key: []CodexKey func (h *Handler) GetCodexKeys(c *gin.Context) { - c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey}) + cfg := h.configSnapshotOrEmpty() + c.JSON(200, gin.H{"codex-api-key": cfg.CodexKey}) } func (h *Handler) PutCodexKeys(c *gin.Context) { - data, err := c.GetRawData() - if err != nil { - c.JSON(400, gin.H{"error": "failed to read body"}) + arr, ok := decodeJSONItems[config.CodexKey](c) + if !ok { return } - var arr []config.CodexKey - if err = json.Unmarshal(data, &arr); err != nil { - var obj struct { - Items []config.CodexKey `json:"items"` - } - if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 { - c.JSON(400, gin.H{"error": "invalid body"}) - return - } - arr = obj.Items - } // Filter out codex entries with empty base-url (treat as removed) filtered := make([]config.CodexKey, 0, len(arr)) for i := range arr { @@ -838,9 +836,11 @@ func (h *Handler) PutCodexKeys(c *gin.Context) { } filtered = append(filtered, entry) } - h.cfg.CodexKey = filtered - h.cfg.SanitizeCodexKeys() - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + cfg.CodexKey = filtered + cfg.SanitizeCodexKeys() + return nil + }) } func (h *Handler) PatchCodexKey(c *gin.Context) { type codexKeyPatch struct { @@ -861,79 +861,75 @@ func (h *Handler) PatchCodexKey(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } - targetIndex := -1 - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { - targetIndex = *body.Index - } - if targetIndex == -1 && body.Match != nil { - match := strings.TrimSpace(*body.Match) - for i := range h.cfg.CodexKey { - if h.cfg.CodexKey[i].APIKey == match { - targetIndex = i - break + h.applyConfigMutation(c, func(cfg *config.Config) error { + targetIndex := findIndexByIndexOrMatch(cfg.CodexKey, body.Index, body.Match, func(item config.CodexKey, match string) bool { + return item.APIKey == match + }) + if targetIndex == -1 { + return fmt.Errorf("item not found") + } + entry := cfg.CodexKey[targetIndex] + if body.Value.APIKey != nil { + entry.APIKey = strings.TrimSpace(*body.Value.APIKey) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + cfg.CodexKey = append(cfg.CodexKey[:targetIndex], cfg.CodexKey[targetIndex+1:]...) + cfg.SanitizeCodexKeys() + return nil } + entry.BaseURL = trimmed } - } - if targetIndex == -1 { - c.JSON(404, gin.H{"error": "item not found"}) - return - } - - entry := h.cfg.CodexKey[targetIndex] - if body.Value.APIKey != nil { - entry.APIKey = strings.TrimSpace(*body.Value.APIKey) - } - if body.Value.Prefix != nil { - entry.Prefix = strings.TrimSpace(*body.Value.Prefix) - } - if body.Value.BaseURL != nil { - trimmed := strings.TrimSpace(*body.Value.BaseURL) - if trimmed == "" { - h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...) - h.cfg.SanitizeCodexKeys() - h.persist(c) - return + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) } - entry.BaseURL = trimmed - } - if body.Value.ProxyURL != nil { - entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) - } - if body.Value.Models != nil { - entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...) - } - if body.Value.Headers != nil { - entry.Headers = config.NormalizeHeaders(*body.Value.Headers) - } - if body.Value.ExcludedModels != nil { - entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) - } - normalizeCodexKey(&entry) - h.cfg.CodexKey[targetIndex] = entry - h.cfg.SanitizeCodexKeys() - h.persist(c) + if body.Value.Models != nil { + entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + normalizeCodexKey(&entry) + cfg.CodexKey[targetIndex] = entry + cfg.SanitizeCodexKeys() + return nil + }) } func (h *Handler) DeleteCodexKey(c *gin.Context) { if val := c.Query("api-key"); val != "" { - out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) - for _, v := range h.cfg.CodexKey { - if v.APIKey != val { - out = append(out, v) + h.applyConfigMutation(c, func(cfg *config.Config) error { + out := make([]config.CodexKey, 0, len(cfg.CodexKey)) + for _, v := range cfg.CodexKey { + if v.APIKey != val { + out = append(out, v) + } } - } - h.cfg.CodexKey = out - h.cfg.SanitizeCodexKeys() - h.persist(c) + cfg.CodexKey = out + cfg.SanitizeCodexKeys() + return nil + }) return } if idxStr := c.Query("index"); idxStr != "" { var idx int _, err := fmt.Sscanf(idxStr, "%d", &idx) - if err == nil && idx >= 0 && idx < len(h.cfg.CodexKey) { - h.cfg.CodexKey = append(h.cfg.CodexKey[:idx], h.cfg.CodexKey[idx+1:]...) - h.cfg.SanitizeCodexKeys() - h.persist(c) + if err == nil { + h.applyConfigMutation(c, func(cfg *config.Config) error { + if idx < 0 || idx >= len(cfg.CodexKey) { + return fmt.Errorf("missing api-key or index") + } + cfg.CodexKey = append(cfg.CodexKey[:idx], cfg.CodexKey[idx+1:]...) + cfg.SanitizeCodexKeys() + return nil + }) return } } @@ -1074,74 +1070,58 @@ func sanitizedOAuthModelAlias(entries map[string][]config.OAuthModelAlias) map[s // GetAmpCode returns the complete ampcode configuration. func (h *Handler) GetAmpCode(c *gin.Context) { - if h == nil || h.cfg == nil { - c.JSON(200, gin.H{"ampcode": config.AmpCode{}}) - return - } - c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode}) + c.JSON(200, gin.H{"ampcode": h.ampCodeSnapshot()}) } // GetAmpUpstreamURL returns the ampcode upstream URL. func (h *Handler) GetAmpUpstreamURL(c *gin.Context) { - if h == nil || h.cfg == nil { - c.JSON(200, gin.H{"upstream-url": ""}) - return - } - c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL}) + c.JSON(200, gin.H{"upstream-url": h.ampCodeSnapshot().UpstreamURL}) } // PutAmpUpstreamURL updates the ampcode upstream URL. func (h *Handler) PutAmpUpstreamURL(c *gin.Context) { - h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) }) + h.updateStringField(c, func(cfg *config.Config, v string) { cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) }) } // DeleteAmpUpstreamURL clears the ampcode upstream URL. func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) { - h.cfg.AmpCode.UpstreamURL = "" - h.persist(c) + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + amp.UpstreamURL = "" + return nil + }) } // GetAmpUpstreamAPIKey returns the ampcode upstream API key. func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) { - if h == nil || h.cfg == nil { - c.JSON(200, gin.H{"upstream-api-key": ""}) - return - } - c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey}) + c.JSON(200, gin.H{"upstream-api-key": h.ampCodeSnapshot().UpstreamAPIKey}) } // PutAmpUpstreamAPIKey updates the ampcode upstream API key. func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) { - h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) }) + h.updateStringField(c, func(cfg *config.Config, v string) { cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) }) } // DeleteAmpUpstreamAPIKey clears the ampcode upstream API key. func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) { - h.cfg.AmpCode.UpstreamAPIKey = "" - h.persist(c) + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + amp.UpstreamAPIKey = "" + return nil + }) } // GetAmpRestrictManagementToLocalhost returns the localhost restriction setting. func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) { - if h == nil || h.cfg == nil { - c.JSON(200, gin.H{"restrict-management-to-localhost": true}) - return - } - c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost}) + c.JSON(200, gin.H{"restrict-management-to-localhost": h.ampCodeSnapshot().RestrictManagementToLocalhost}) } // PutAmpRestrictManagementToLocalhost updates the localhost restriction setting. func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.AmpCode.RestrictManagementToLocalhost = v }) } // GetAmpModelMappings returns the ampcode model mappings. func (h *Handler) GetAmpModelMappings(c *gin.Context) { - if h == nil || h.cfg == nil { - c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}}) - return - } - c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings}) + c.JSON(200, gin.H{"model-mappings": h.ampCodeSnapshot().ModelMappings}) } // PutAmpModelMappings replaces all ampcode model mappings. @@ -1153,8 +1133,10 @@ func (h *Handler) PutAmpModelMappings(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } - h.cfg.AmpCode.ModelMappings = body.Value - h.persist(c) + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + amp.ModelMappings = body.Value + return nil + }) } // PatchAmpModelMappings adds or updates model mappings. @@ -1167,70 +1149,71 @@ func (h *Handler) PatchAmpModelMappings(c *gin.Context) { return } - existing := make(map[string]int) - for i, m := range h.cfg.AmpCode.ModelMappings { - existing[strings.TrimSpace(m.From)] = i - } - - for _, newMapping := range body.Value { - from := strings.TrimSpace(newMapping.From) - if idx, ok := existing[from]; ok { - h.cfg.AmpCode.ModelMappings[idx] = newMapping - } else { - h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping) - existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1 + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + existing := make(map[string]int) + for i, m := range amp.ModelMappings { + existing[strings.TrimSpace(m.From)] = i } - } - h.persist(c) + for _, newMapping := range body.Value { + from := strings.TrimSpace(newMapping.From) + if idx, ok := existing[from]; ok { + amp.ModelMappings[idx] = newMapping + } else { + amp.ModelMappings = append(amp.ModelMappings, newMapping) + existing[from] = len(amp.ModelMappings) - 1 + } + } + return nil + }) } // DeleteAmpModelMappings removes specified model mappings by "from" field. func (h *Handler) DeleteAmpModelMappings(c *gin.Context) { var body struct { - Value []string `json:"value"` + Value *[]string `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return } - if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 { - h.cfg.AmpCode.ModelMappings = nil - h.persist(c) + if len(*body.Value) == 0 { + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + amp.ModelMappings = nil + return nil + }) return } toRemove := make(map[string]bool) - for _, from := range body.Value { + for _, from := range *body.Value { toRemove[strings.TrimSpace(from)] = true } - newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings)) - for _, m := range h.cfg.AmpCode.ModelMappings { - if !toRemove[strings.TrimSpace(m.From)] { - newMappings = append(newMappings, m) + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + newMappings := make([]config.AmpModelMapping, 0, len(amp.ModelMappings)) + for _, m := range amp.ModelMappings { + if !toRemove[strings.TrimSpace(m.From)] { + newMappings = append(newMappings, m) + } } - } - h.cfg.AmpCode.ModelMappings = newMappings - h.persist(c) + amp.ModelMappings = newMappings + return nil + }) } // GetAmpForceModelMappings returns whether model mappings are forced. func (h *Handler) GetAmpForceModelMappings(c *gin.Context) { - if h == nil || h.cfg == nil { - c.JSON(200, gin.H{"force-model-mappings": false}) - return - } - c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings}) + c.JSON(200, gin.H{"force-model-mappings": h.ampCodeSnapshot().ForceModelMappings}) } // PutAmpForceModelMappings updates the force model mappings setting. func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.AmpCode.ForceModelMappings = v }) } // GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping. func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) { - if h == nil || h.cfg == nil { - c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}}) - return - } - c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys}) + c.JSON(200, gin.H{"upstream-api-keys": h.ampCodeSnapshot().UpstreamAPIKeys}) } // PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings. @@ -1244,8 +1227,10 @@ func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) { } // Normalize entries: trim whitespace, filter empty normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value) - h.cfg.AmpCode.UpstreamAPIKeys = normalized - h.persist(c) + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + amp.UpstreamAPIKeys = normalized + return nil + }) } // PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries. @@ -1259,28 +1244,29 @@ func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) { return } - existing := make(map[string]int) - for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys { - existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i - } - - for _, newEntry := range body.Value { - upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey) - if upstreamKey == "" { - continue - } - normalizedEntry := config.AmpUpstreamAPIKeyEntry{ - UpstreamAPIKey: upstreamKey, - APIKeys: normalizeAPIKeysList(newEntry.APIKeys), + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + existing := make(map[string]int) + for i, entry := range amp.UpstreamAPIKeys { + existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i } - if idx, ok := existing[upstreamKey]; ok { - h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry - } else { - h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry) - existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1 + for _, newEntry := range body.Value { + upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey) + if upstreamKey == "" { + continue + } + normalizedEntry := config.AmpUpstreamAPIKeyEntry{ + UpstreamAPIKey: upstreamKey, + APIKeys: normalizeAPIKeysList(newEntry.APIKeys), + } + if idx, ok := existing[upstreamKey]; ok { + amp.UpstreamAPIKeys[idx] = normalizedEntry + } else { + amp.UpstreamAPIKeys = append(amp.UpstreamAPIKeys, normalizedEntry) + existing[upstreamKey] = len(amp.UpstreamAPIKeys) - 1 + } } - } - h.persist(c) + return nil + }) } // DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries. @@ -1303,8 +1289,10 @@ func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) { // Empty array means clear all if len(body.Value) == 0 { - h.cfg.AmpCode.UpstreamAPIKeys = nil - h.persist(c) + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + amp.UpstreamAPIKeys = nil + return nil + }) return } @@ -1321,14 +1309,16 @@ func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) { return } - newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys)) - for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys { - if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] { - newEntries = append(newEntries, entry) + h.mutateAmpCode(c, func(amp *config.AmpCode) error { + newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(amp.UpstreamAPIKeys)) + for _, entry := range amp.UpstreamAPIKeys { + if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] { + newEntries = append(newEntries, entry) + } } - } - h.cfg.AmpCode.UpstreamAPIKeys = newEntries - h.persist(c) + amp.UpstreamAPIKeys = newEntries + return nil + }) } // normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries. diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index 45786b9d3e..09237f548e 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -4,6 +4,7 @@ package management import ( "crypto/subtle" + "errors" "fmt" "net/http" "os" @@ -19,6 +20,7 @@ import ( sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v3" ) type attemptInfo struct { @@ -27,6 +29,26 @@ type attemptInfo struct { lastActivity time.Time // track last activity for cleanup } +type RuntimeApplier func(*config.Config) error + +type runtimeApplyError struct { + cause error +} + +func (e *runtimeApplyError) Error() string { + if e == nil || e.cause == nil { + return "runtime apply failed" + } + return e.cause.Error() +} + +func (e *runtimeApplyError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + // attemptCleanupInterval controls how often stale IP entries are purged const attemptCleanupInterval = 1 * time.Hour @@ -38,6 +60,7 @@ type Handler struct { cfg *config.Config configFilePath string mu sync.Mutex + stateMu sync.RWMutex attemptsMu sync.Mutex failedAttempts map[string]*attemptInfo // keyed by client IP authManager *coreauth.Manager @@ -48,6 +71,36 @@ type Handler struct { envSecret string logDir string postAuthHook coreauth.PostAuthHook + runtimeApplier RuntimeApplier +} + +type runtimeStateSnapshot struct { + cfg *config.Config + configFilePath string + authManager *coreauth.Manager + usageStats *usage.RequestStatistics + tokenStore coreauth.Store + localPassword string + allowRemoteOverride bool + envSecret string + logDir string + postAuthHook coreauth.PostAuthHook + runtimeApplier RuntimeApplier +} + +type authDirProvider interface { + AuthDir() string +} + +// ResolveEffectiveAuthDir returns the runtime auth directory used by file-based flows. +// Stores with mirrored workspaces may override the configured auth dir. +func ResolveEffectiveAuthDir(configAuthDir string, store coreauth.Store) string { + if provider, ok := store.(authDirProvider); ok { + if dir := strings.TrimSpace(provider.AuthDir()); dir != "" { + return dir + } + } + return strings.TrimSpace(configAuthDir) } // NewHandler creates a new management handler instance. @@ -104,17 +157,41 @@ func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manag return NewHandler(cfg, "", manager) } +// StateMiddleware is kept for compatibility but no longer serializes the entire +// request lifecycle. Handlers should use short-lived snapshots and mutation helpers instead. +func (h *Handler) StateMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + c.Next() + } +} + // SetConfig updates the in-memory config reference when the server hot-reloads. -func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg } +func (h *Handler) SetConfig(cfg *config.Config) { + h.stateMu.Lock() + h.cfg = cfg + h.stateMu.Unlock() +} // SetAuthManager updates the auth manager reference used by management endpoints. -func (h *Handler) SetAuthManager(manager *coreauth.Manager) { h.authManager = manager } +func (h *Handler) SetAuthManager(manager *coreauth.Manager) { + h.stateMu.Lock() + h.authManager = manager + h.stateMu.Unlock() +} // SetUsageStatistics allows replacing the usage statistics reference. -func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats } +func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { + h.stateMu.Lock() + h.usageStats = stats + h.stateMu.Unlock() +} // SetLocalPassword configures the runtime-local password accepted for localhost requests. -func (h *Handler) SetLocalPassword(password string) { h.localPassword = password } +func (h *Handler) SetLocalPassword(password string) { + h.stateMu.Lock() + h.localPassword = password + h.stateMu.Unlock() +} // SetLogDirectory updates the directory where main.log should be looked up. func (h *Handler) SetLogDirectory(dir string) { @@ -126,12 +203,100 @@ func (h *Handler) SetLogDirectory(dir string) { dir = abs } } + h.stateMu.Lock() h.logDir = dir + h.stateMu.Unlock() } // SetPostAuthHook registers a hook to be called after auth record creation but before persistence. func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) { + h.stateMu.Lock() h.postAuthHook = hook + h.stateMu.Unlock() +} + +// SetRuntimeApplier registers the synchronous runtime config apply callback. +func (h *Handler) SetRuntimeApplier(applier RuntimeApplier) { + h.stateMu.Lock() + h.runtimeApplier = applier + h.stateMu.Unlock() +} + +func cloneConfig(cfg *config.Config) (*config.Config, error) { + if cfg == nil { + return nil, nil + } + raw, err := yaml.Marshal(cfg) + if err != nil { + return nil, err + } + var cloned config.Config + if err := yaml.Unmarshal(raw, &cloned); err != nil { + return nil, err + } + return &cloned, nil +} + +func (h *Handler) runtimeSnapshot() (*runtimeStateSnapshot, error) { + if h == nil { + return &runtimeStateSnapshot{}, nil + } + h.stateMu.RLock() + cfg, err := cloneConfig(h.cfg) + snapshot := &runtimeStateSnapshot{ + cfg: cfg, + configFilePath: h.configFilePath, + authManager: h.authManager, + usageStats: h.usageStats, + tokenStore: h.tokenStore, + localPassword: h.localPassword, + allowRemoteOverride: h.allowRemoteOverride, + envSecret: h.envSecret, + logDir: h.logDir, + postAuthHook: h.postAuthHook, + runtimeApplier: h.runtimeApplier, + } + h.stateMu.RUnlock() + if snapshot.tokenStore == nil { + snapshot.tokenStore = sdkAuth.GetTokenStore() + } + if err != nil { + return snapshot, err + } + return snapshot, nil +} + +func (h *Handler) configSnapshot() (*config.Config, error) { + snapshot, err := h.runtimeSnapshot() + if err != nil { + return nil, err + } + return snapshot.cfg, nil +} + +func effectiveAuthDirFromSnapshot(cfg *config.Config, store coreauth.Store) string { + if cfg == nil { + return ResolveEffectiveAuthDir("", store) + } + return ResolveEffectiveAuthDir(cfg.AuthDir, store) +} + +func (h *Handler) effectiveAuthDir() string { + snapshot, err := h.runtimeSnapshot() + if err != nil { + return "" + } + return effectiveAuthDirFromSnapshot(snapshot.cfg, snapshot.tokenStore) +} + +func (h *Handler) registerOAuthSession(state, provider string) string { + snapshot, err := h.runtimeSnapshot() + if err != nil { + return "" + } + authDir := effectiveAuthDirFromSnapshot(snapshot.cfg, snapshot.tokenStore) + RegisterOAuthSession(state, provider, authDir) + return authDir } // Middleware enforces access control for management endpoints. @@ -148,7 +313,12 @@ func (h *Handler) Middleware() gin.HandlerFunc { clientIP := c.ClientIP() localClient := clientIP == "127.0.0.1" || clientIP == "::1" + h.stateMu.RLock() cfg := h.cfg + localPassword := h.localPassword + allowRemoteOverride := h.allowRemoteOverride + envSecret := h.envSecret + h.stateMu.RUnlock() var ( allowRemote bool secretHash string @@ -157,10 +327,9 @@ func (h *Handler) Middleware() gin.HandlerFunc { allowRemote = cfg.RemoteManagement.AllowRemote secretHash = cfg.RemoteManagement.SecretKey } - if h.allowRemoteOverride { + if allowRemoteOverride { allowRemote = true } - envSecret := h.envSecret fail := func() {} if !localClient { @@ -230,8 +399,8 @@ func (h *Handler) Middleware() gin.HandlerFunc { } if localClient { - if lp := h.localPassword; lp != "" { - if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 { + if localPassword != "" { + if subtle.ConstantTimeCompare([]byte(provided), []byte(localPassword)) == 1 { c.Next() return } @@ -272,21 +441,98 @@ func (h *Handler) Middleware() gin.HandlerFunc { } } +func (h *Handler) applyRuntimeConfig(snapshot *runtimeStateSnapshot, cfg *config.Config) error { + if snapshot == nil || snapshot.runtimeApplier == nil || cfg == nil { + return nil + } + if err := snapshot.runtimeApplier(cfg); err != nil { + return &runtimeApplyError{cause: err} + } + return nil +} + +func (h *Handler) reloadCommittedConfig(snapshot *runtimeStateSnapshot) (*config.Config, error) { + if snapshot == nil { + return nil, fmt.Errorf("failed to snapshot config: snapshot unavailable") + } + committedCfg, err := config.LoadConfig(snapshot.configFilePath) + if err != nil { + return nil, fmt.Errorf("failed to reload config: %w", err) + } + if err := h.applyRuntimeConfig(snapshot, committedCfg); err != nil { + var applyErr *runtimeApplyError + if errors.As(err, &applyErr) && applyErr != nil && applyErr.cause != nil { + err = applyErr.cause + } + return nil, fmt.Errorf("failed to apply runtime config: %w", err) + } + h.stateMu.Lock() + h.cfg = committedCfg + h.stateMu.Unlock() + return committedCfg, nil +} + +func (h *Handler) commitConfig(cfg *config.Config) (*config.Config, error) { + if cfg == nil { + return nil, fmt.Errorf("configuration unavailable") + } + snapshot, err := h.runtimeSnapshot() + if err != nil { + return nil, fmt.Errorf("failed to snapshot config: %w", err) + } + if err := config.SaveConfigPreserveComments(snapshot.configFilePath, cfg); err != nil { + return nil, fmt.Errorf("failed to save config: %w", err) + } + return h.reloadCommittedConfig(snapshot) +} + +func (h *Handler) persistConfig(c *gin.Context, cfg *config.Config) bool { + if _, err := h.commitConfig(cfg); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return false + } + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + return true +} + // persist saves the current in-memory config to disk. func (h *Handler) persist(c *gin.Context) bool { h.mu.Lock() defer h.mu.Unlock() - // Preserve comments when writing - if err := config.SaveConfigPreserveComments(h.configFilePath, h.cfg); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save config: %v", err)}) + cfg, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to snapshot config: %v", err)}) return false } - c.JSON(http.StatusOK, gin.H{"status": "ok"}) - return true + return h.persistConfig(c, cfg) +} + +func (h *Handler) applyConfigMutation(c *gin.Context, mutate func(*config.Config) error) bool { + h.mu.Lock() + defer h.mu.Unlock() + + current, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to snapshot config: %v", err)}) + return false + } + if current == nil { + current = &config.Config{} + } + nextCfg, err := cloneConfig(current) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to clone config: %v", err)}) + return false + } + if err := mutate(nextCfg); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return false + } + return h.persistConfig(c, nextCfg) } // Helper methods for simple types -func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) { +func (h *Handler) updateBoolField(c *gin.Context, set func(*config.Config, bool)) { var body struct { Value *bool `json:"value"` } @@ -294,11 +540,13 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) return } - set(*body.Value) - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + set(cfg, *body.Value) + return nil + }) } -func (h *Handler) updateIntField(c *gin.Context, set func(int)) { +func (h *Handler) updateIntField(c *gin.Context, set func(*config.Config, int) error) { var body struct { Value *int `json:"value"` } @@ -306,11 +554,12 @@ func (h *Handler) updateIntField(c *gin.Context, set func(int)) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) return } - set(*body.Value) - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + return set(cfg, *body.Value) + }) } -func (h *Handler) updateStringField(c *gin.Context, set func(string)) { +func (h *Handler) updateStringField(c *gin.Context, set func(*config.Config, string)) { var body struct { Value *string `json:"value"` } @@ -318,6 +567,8 @@ func (h *Handler) updateStringField(c *gin.Context, set func(string)) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) return } - set(*body.Value) - h.persist(c) + h.applyConfigMutation(c, func(cfg *config.Config) error { + set(cfg, *body.Value) + return nil + }) } diff --git a/internal/api/handlers/management/handler_state_test.go b/internal/api/handlers/management/handler_state_test.go new file mode 100644 index 0000000000..e282b6f1c4 --- /dev/null +++ b/internal/api/handlers/management/handler_state_test.go @@ -0,0 +1,285 @@ +package management + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "golang.org/x/crypto/bcrypt" +) + +func TestStateMiddleware_DoesNotBlockConcurrentSetConfig(t *testing.T) { + gin.SetMode(gin.TestMode) + + hash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("generate hash: %v", err) + } + + h := NewHandler(&config.Config{RemoteManagement: config.RemoteManagement{SecretKey: string(hash)}}, filepath.Join(t.TempDir(), "config.yaml"), nil) + r := gin.New() + release := make(chan struct{}) + reached := make(chan struct{}) + r.GET("/guarded", h.StateMiddleware(), func(c *gin.Context) { + close(reached) + <-release + c.Status(http.StatusNoContent) + }) + + go func() { + req := httptest.NewRequest(http.MethodGet, "/guarded", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + }() + + select { + case <-reached: + case <-time.After(2 * time.Second): + t.Fatal("request did not enter guarded handler") + } + + updated := make(chan struct{}) + go func() { + h.SetConfig(&config.Config{}) + close(updated) + }() + + select { + case <-updated: + case <-time.After(150 * time.Millisecond): + t.Fatal("SetConfig should not wait for in-flight management request") + } + + close(release) + + select { + case <-updated: + case <-time.After(2 * time.Second): + t.Fatal("SetConfig did not complete after request finished") + } +} + +func TestPutLogsMaxTotalSizeMB_RejectsOversizedValue(t *testing.T) { + gin.SetMode(gin.TestMode) + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("logging-to-file: false\n"), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + h := NewHandler(&config.Config{}, configPath, nil) + r := gin.New() + r.PUT("/logs-max-total-size-mb", h.StateMiddleware(), h.PutLogsMaxTotalSizeMB) + + body := strings.NewReader(`{"value":1048577}`) + req := httptest.NewRequest(http.MethodPut, "/logs-max-total-size-mb", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: got %d body=%s", w.Code, w.Body.String()) + } + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if !strings.Contains(resp["error"], "exceeds allowed maximum") { + t.Fatalf("unexpected error response: %v", resp) + } +} + +func TestStateMiddleware_DoesNotDeadlockRegisterOAuthSession(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewHandler(&config.Config{AuthDir: t.TempDir()}, filepath.Join(t.TempDir(), "config.yaml"), nil) + r := gin.New() + r.GET("/oauth", h.StateMiddleware(), func(c *gin.Context) { + h.registerOAuthSession("state-for-deadlock-check", "codex") + c.Status(http.StatusNoContent) + }) + + done := make(chan struct{}) + go func() { + req := httptest.NewRequest(http.MethodGet, "/oauth", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + close(done) + }() + + select { + case <-done: + case <-time.After(300 * time.Millisecond): + t.Fatal("registerOAuthSession should not block while management request is in flight") + } +} + +func TestApplyConfigMutation_AppliesPersistedConfigViaRuntimeApplier(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("request-log: false\n"), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + h := NewHandler(&config.Config{}, configPath, nil) + applied := 0 + applierSawPersisted := false + h.SetRuntimeApplier(func(cfg *config.Config) error { + applied++ + if cfg != nil && cfg.RequestLog { + persisted, err := os.ReadFile(configPath) + if err == nil && strings.Contains(string(persisted), "request-log: true") { + applierSawPersisted = true + } + } + return nil + }) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + if !h.applyConfigMutation(ctx, func(cfg *config.Config) error { + cfg.RequestLog = true + return nil + }) { + t.Fatalf("expected applyConfigMutation to succeed, got %d body=%s", recorder.Code, recorder.Body.String()) + } + if recorder.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d body=%s", http.StatusOK, recorder.Code, recorder.Body.String()) + } + if applied != 1 { + t.Fatalf("expected runtime applier to be called once, got %d", applied) + } + if !applierSawPersisted { + t.Fatal("expected runtime applier to observe committed config on disk") + } + + snapshot, err := h.runtimeSnapshot() + if err != nil { + t.Fatalf("runtime snapshot: %v", err) + } + if snapshot.cfg == nil || !snapshot.cfg.RequestLog { + t.Fatalf("expected runtime snapshot config to include request-log=true, got %+v", snapshot.cfg) + } +} + +func TestPutConfigYAML_ClampsOversizedLogLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("logging-to-file: false\n"), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + h := NewHandler(&config.Config{}, configPath, nil) + applied := 0 + applierSawClamped := false + applierSawPersisted := false + h.SetRuntimeApplier(func(cfg *config.Config) error { + applied++ + if cfg != nil && cfg.LogsMaxTotalSizeMB == config.MaxLogsMaxTotalSizeMB { + applierSawClamped = true + } + persisted, err := os.ReadFile(configPath) + if err == nil && strings.Contains(string(persisted), "logs-max-total-size-mb: 1024") { + applierSawPersisted = true + } + return nil + }) + + r := gin.New() + r.PUT("/config.yaml", h.PutConfigYAML) + + req := httptest.NewRequest(http.MethodPut, "/config.yaml", strings.NewReader("logs-max-total-size-mb: 1048577\n")) + req.Header.Set("Content-Type", "application/yaml") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected oversized config to be clamped, got %d body=%s", w.Code, w.Body.String()) + } + if applied != 1 { + t.Fatalf("expected runtime applier to be called once, got %d", applied) + } + if !applierSawClamped { + t.Fatal("expected runtime applier to receive clamped committed config") + } + if !applierSawPersisted { + t.Fatal("expected runtime applier to observe committed config on disk") + } + + snapshot, err := h.runtimeSnapshot() + if err != nil { + t.Fatalf("runtime snapshot: %v", err) + } + if snapshot.cfg == nil { + t.Fatal("expected runtime snapshot config to be available") + } + if snapshot.cfg.LogsMaxTotalSizeMB != config.MaxLogsMaxTotalSizeMB { + t.Fatalf("expected logs-max-total-size-mb to be clamped to %d, got %d", config.MaxLogsMaxTotalSizeMB, snapshot.cfg.LogsMaxTotalSizeMB) + } + + persisted, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("read persisted config: %v", err) + } + persistedText := string(persisted) + if strings.Contains(persistedText, "1048577") { + t.Fatalf("expected persisted config to remove oversized value, got %s", persistedText) + } + if !strings.Contains(persistedText, "logs-max-total-size-mb: 1024") { + t.Fatalf("expected persisted config to contain clamped value, got %s", persistedText) + } +} + +func TestApplyConfigMutation_RuntimeApplierMayUpdateHandlerState(t *testing.T) { + gin.SetMode(gin.TestMode) + + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("request-log: false\n"), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + h := NewHandler(&config.Config{}, configPath, nil) + done := make(chan struct{}) + h.SetRuntimeApplier(func(cfg *config.Config) error { + h.SetConfig(cfg) + close(done) + return nil + }) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + finished := make(chan struct{}) + go func() { + defer close(finished) + if !h.applyConfigMutation(ctx, func(cfg *config.Config) error { + cfg.RequestLog = true + return nil + }) { + t.Errorf("expected applyConfigMutation to succeed, got %d body=%s", recorder.Code, recorder.Body.String()) + } + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("runtime applier should be able to update handler state without deadlocking") + } + + select { + case <-finished: + case <-time.After(2 * time.Second): + t.Fatal("applyConfigMutation should finish after runtime applier updates handler state") + } +} diff --git a/internal/api/handlers/management/logs.go b/internal/api/handlers/management/logs.go index b64cd61938..d38e50e601 100644 --- a/internal/api/handlers/management/logs.go +++ b/internal/api/handlers/management/logs.go @@ -28,11 +28,16 @@ func (h *Handler) GetLogs(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"}) return } - if h.cfg == nil { + cfg, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } + if cfg == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"}) return } - if !h.cfg.LoggingToFile { + if !cfg.LoggingToFile { c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"}) return } @@ -90,11 +95,16 @@ func (h *Handler) DeleteLogs(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"}) return } - if h.cfg == nil { + cfg, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } + if cfg == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"}) return } - if !h.cfg.LoggingToFile { + if !cfg.LoggingToFile { c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"}) return } @@ -152,11 +162,16 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"}) return } - if h.cfg == nil { + cfg, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } + if cfg == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"}) return } - if h.cfg.RequestLog { + if cfg.RequestLog { c.JSON(http.StatusOK, gin.H{"files": []any{}}) return } @@ -216,7 +231,12 @@ func (h *Handler) GetRequestLogByID(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"}) return } - if h.cfg == nil { + cfg, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } + if cfg == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"}) return } @@ -303,7 +323,12 @@ func (h *Handler) DownloadRequestErrorLog(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"}) return } - if h.cfg == nil { + cfg, err := h.configSnapshot() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "configuration unavailable"}) + return + } + if cfg == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"}) return } @@ -357,10 +382,14 @@ func (h *Handler) logDirectory() string { if h == nil { return "" } - if h.logDir != "" { - return h.logDir + snapshot, err := h.runtimeSnapshot() + if err == nil { + if strings.TrimSpace(snapshot.logDir) != "" { + return snapshot.logDir + } + return logging.ResolveLogDirectory(snapshot.cfg) } - return logging.ResolveLogDirectory(h.cfg) + return "" } func (h *Handler) collectLogFiles(dir string) ([]string, error) { diff --git a/internal/api/handlers/management/oauth_callback.go b/internal/api/handlers/management/oauth_callback.go index c69a332ee7..fcd0b93dfa 100644 --- a/internal/api/handlers/management/oauth_callback.go +++ b/internal/api/handlers/management/oauth_callback.go @@ -18,7 +18,7 @@ type oauthCallbackRequest struct { } func (h *Handler) PostOAuthCallback(c *gin.Context) { - if h == nil || h.cfg == nil { + if h == nil { c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"}) return } @@ -87,11 +87,12 @@ func (h *Handler) PostOAuthCallback(c *gin.Context) { return } - if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil { + if _, errWrite := WriteOAuthCallbackFileForPendingSession(canonicalProvider, state, code, errMsg); errWrite != nil { if errors.Is(errWrite, errOAuthSessionNotPending) { c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"}) return } + SetOAuthSessionError(state, errMsg) c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"}) return } diff --git a/internal/api/handlers/management/oauth_callback_test.go b/internal/api/handlers/management/oauth_callback_test.go new file mode 100644 index 0000000000..c99e363f07 --- /dev/null +++ b/internal/api/handlers/management/oauth_callback_test.go @@ -0,0 +1,106 @@ +package management + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestPostOAuthCallback_UsesSessionAuthDir(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + tempDir := t.TempDir() + configAuthDir := filepath.Join(tempDir, "config-auth") + sessionAuthDir := filepath.Join(tempDir, "session-auth") + if err := os.MkdirAll(configAuthDir, 0o700); err != nil { + t.Fatalf("failed to create config auth dir: %v", err) + } + if err := os.MkdirAll(sessionAuthDir, 0o700); err != nil { + t.Fatalf("failed to create session auth dir: %v", err) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: configAuthDir}, coreauth.NewManager(nil, nil, nil)) + state := "codex-session-bound-state" + RegisterOAuthSession(state, "codex", sessionAuthDir) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/v0/management/oauth/callback", strings.NewReader(`{"provider":"codex","state":"`+state+`","code":"auth-code"}`)) + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + + h.PostOAuthCallback(ctx) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, recorder.Code, recorder.Body.String()) + } + + sessionPath := filepath.Join(sessionAuthDir, ".oauth-codex-"+state+".oauth") + configPath := filepath.Join(configAuthDir, ".oauth-codex-"+state+".oauth") + if _, err := os.Stat(sessionPath); err != nil { + t.Fatalf("expected callback file in session auth dir, stat err: %v", err) + } + if _, err := os.Stat(configPath); !os.IsNotExist(err) { + t.Fatalf("expected no callback file in config auth dir, stat err: %v", err) + } + + var payload map[string]string + raw, err := os.ReadFile(sessionPath) + if err != nil { + t.Fatalf("failed to read callback file: %v", err) + } + if err := json.Unmarshal(raw, &payload); err != nil { + t.Fatalf("failed to decode callback file: %v", err) + } + if payload["code"] != "auth-code" { + t.Fatalf("expected code to be persisted, got %q", payload["code"]) + } +} + +func TestPostOAuthCallback_WriteFailureMarksSessionError(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + blockedAuthDir := filepath.Join(t.TempDir(), "blocked-auth-dir") + if err := os.WriteFile(blockedAuthDir, []byte("blocked"), 0o600); err != nil { + t.Fatalf("failed to create blocking auth path: %v", err) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{}, coreauth.NewManager(nil, nil, nil)) + state := "codex-session-write-failure" + RegisterOAuthSession(state, "codex", blockedAuthDir) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/v0/management/oauth/callback", strings.NewReader(`{"provider":"codex","state":"`+state+`","code":"auth-code"}`)) + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + + h.PostOAuthCallback(ctx) + + if recorder.Code != http.StatusInternalServerError { + t.Fatalf("expected status %d, got %d with body %s", http.StatusInternalServerError, recorder.Code, recorder.Body.String()) + } + if IsOAuthSessionPending(state, "codex") { + t.Fatal("expected failed callback write to end pending oauth session") + } + provider, status, ok := GetOAuthSession(state) + if !ok { + t.Fatal("expected oauth session to remain available with error status") + } + if provider != "codex" { + t.Fatalf("expected provider codex, got %q", provider) + } + if strings.TrimSpace(status) == "" { + t.Fatal("expected oauth session error status to be recorded") + } +} diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go index 05ff8d1f52..f26f3f89a0 100644 --- a/internal/api/handlers/management/oauth_sessions.go +++ b/internal/api/handlers/management/oauth_sessions.go @@ -25,6 +25,7 @@ var ( type oauthSession struct { Provider string Status string + AuthDir string CreatedAt time.Time ExpiresAt time.Time } @@ -53,9 +54,10 @@ func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) { } } -func (s *oauthSessionStore) Register(state, provider string) { +func (s *oauthSessionStore) Register(state, provider, authDir string) { state = strings.TrimSpace(state) provider = strings.ToLower(strings.TrimSpace(provider)) + authDir = strings.TrimSpace(authDir) if state == "" || provider == "" { return } @@ -68,6 +70,7 @@ func (s *oauthSessionStore) Register(state, provider string) { s.sessions[state] = oauthSession{ Provider: provider, Status: "", + AuthDir: authDir, CreatedAt: now, ExpiresAt: now.Add(s.ttl), } @@ -168,7 +171,9 @@ func (s *oauthSessionStore) IsPending(state, provider string) bool { var oauthSessions = newOAuthSessionStore(oauthSessionTTL) -func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) } +func RegisterOAuthSession(state, provider, authDir string) { + oauthSessions.Register(state, provider, authDir) +} func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) } @@ -265,19 +270,26 @@ func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) if err != nil { return "", fmt.Errorf("marshal oauth callback payload: %w", err) } + if err := os.MkdirAll(authDir, 0o700); err != nil { + return "", fmt.Errorf("create oauth callback dir: %w", err) + } if err := os.WriteFile(filePath, data, 0o600); err != nil { return "", fmt.Errorf("write oauth callback file: %w", err) } return filePath, nil } -func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) { +func WriteOAuthCallbackFileForPendingSession(provider, state, code, errorMessage string) (string, error) { canonicalProvider, err := NormalizeOAuthProvider(provider) if err != nil { return "", err } - if !IsOAuthSessionPending(state, canonicalProvider) { + session, ok := oauthSessions.Get(state) + if !ok || session.Status != "" || !strings.EqualFold(session.Provider, canonicalProvider) { return "", errOAuthSessionNotPending } - return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage) + if strings.TrimSpace(session.AuthDir) == "" { + return "", fmt.Errorf("oauth session auth dir is empty") + } + return WriteOAuthCallbackFile(session.AuthDir, canonicalProvider, state, code, errorMessage) } diff --git a/internal/api/handlers/management/quota.go b/internal/api/handlers/management/quota.go index c7efd217bd..ba58e43b9b 100644 --- a/internal/api/handlers/management/quota.go +++ b/internal/api/handlers/management/quota.go @@ -1,18 +1,31 @@ package management -import "github.com/gin-gonic/gin" +import ( + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) // Quota exceeded toggles func (h *Handler) GetSwitchProject(c *gin.Context) { - c.JSON(200, gin.H{"switch-project": h.cfg.QuotaExceeded.SwitchProject}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(200, gin.H{"switch-project": false}) + return + } + c.JSON(200, gin.H{"switch-project": cfg.QuotaExceeded.SwitchProject}) } func (h *Handler) PutSwitchProject(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchProject = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.QuotaExceeded.SwitchProject = v }) } func (h *Handler) GetSwitchPreviewModel(c *gin.Context) { - c.JSON(200, gin.H{"switch-preview-model": h.cfg.QuotaExceeded.SwitchPreviewModel}) + cfg, err := h.configSnapshot() + if err != nil || cfg == nil { + c.JSON(200, gin.H{"switch-preview-model": false}) + return + } + c.JSON(200, gin.H{"switch-preview-model": cfg.QuotaExceeded.SwitchPreviewModel}) } func (h *Handler) PutSwitchPreviewModel(c *gin.Context) { - h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchPreviewModel = v }) + h.updateBoolField(c, func(cfg *config.Config, v bool) { cfg.QuotaExceeded.SwitchPreviewModel = v }) } diff --git a/internal/api/handlers/management/usage.go b/internal/api/handlers/management/usage.go index 5f79408963..be01cd8f51 100644 --- a/internal/api/handlers/management/usage.go +++ b/internal/api/handlers/management/usage.go @@ -23,8 +23,11 @@ type usageImportPayload struct { // GetUsageStatistics returns the in-memory request statistics snapshot. func (h *Handler) GetUsageStatistics(c *gin.Context) { var snapshot usage.StatisticsSnapshot - if h != nil && h.usageStats != nil { - snapshot = h.usageStats.Snapshot() + if h != nil { + state, err := h.runtimeSnapshot() + if err == nil && state.usageStats != nil { + snapshot = state.usageStats.Snapshot() + } } c.JSON(http.StatusOK, gin.H{ "usage": snapshot, @@ -35,8 +38,11 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) { // ExportUsageStatistics returns a complete usage snapshot for backup/migration. func (h *Handler) ExportUsageStatistics(c *gin.Context) { var snapshot usage.StatisticsSnapshot - if h != nil && h.usageStats != nil { - snapshot = h.usageStats.Snapshot() + if h != nil { + state, err := h.runtimeSnapshot() + if err == nil && state.usageStats != nil { + snapshot = state.usageStats.Snapshot() + } } c.JSON(http.StatusOK, usageExportPayload{ Version: 1, @@ -47,7 +53,12 @@ func (h *Handler) ExportUsageStatistics(c *gin.Context) { // ImportUsageStatistics merges a previously exported usage snapshot into memory. func (h *Handler) ImportUsageStatistics(c *gin.Context) { - if h == nil || h.usageStats == nil { + if h == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"}) + return + } + state, err := h.runtimeSnapshot() + if err != nil || state.usageStats == nil { c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"}) return } @@ -68,8 +79,8 @@ func (h *Handler) ImportUsageStatistics(c *gin.Context) { return } - result := h.usageStats.MergeSnapshot(payload.Usage) - snapshot := h.usageStats.Snapshot() + result := state.usageStats.MergeSnapshot(payload.Usage) + snapshot := state.usageStats.Snapshot() c.JSON(http.StatusOK, gin.H{ "added": result.Added, "skipped": result.Skipped, diff --git a/internal/api/handlers/management/vertex_import.go b/internal/api/handlers/management/vertex_import.go index bad066a270..10e0b8ed34 100644 --- a/internal/api/handlers/management/vertex_import.go +++ b/internal/api/handlers/management/vertex_import.go @@ -15,11 +15,12 @@ import ( // ImportVertexCredential handles uploading a Vertex service account JSON and saving it as an auth record. func (h *Handler) ImportVertexCredential(c *gin.Context) { - if h == nil || h.cfg == nil { + cfg, err := h.configSnapshot() + if h == nil || err != nil || cfg == nil { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"}) return } - if h.cfg.AuthDir == "" { + if cfg.AuthDir == "" { c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"}) return } diff --git a/internal/api/server.go b/internal/api/server.go index 0325ca30ce..b0c3e20b7c 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -20,6 +20,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/access" + configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access" managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" "github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware" "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" @@ -41,6 +42,15 @@ import ( ) const oauthCallbackSuccessHTML = `
You can close this window.
This window will close automatically in 5 seconds.
` +const oauthCallbackFailureHTML = `Please return to the application and try again.
` + +func writePendingOAuthCallbackFile(provider, state, code, errStr string) error { + _, err := managementHandlers.WriteOAuthCallbackFileForPendingSession(provider, state, code, errStr) + if err != nil { + managementHandlers.SetOAuthSessionError(state, strings.TrimSpace(errStr)) + } + return err +} type serverOptionConfig struct { extraMiddleware []gin.HandlerFunc @@ -158,6 +168,8 @@ type Server struct { // management handler mgmt *managementHandlers.Handler + updateClientsMu sync.Mutex + // ampModule is the Amp routing module for model mapping hot-reload ampModule *ampmodule.AmpModule @@ -195,6 +207,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk for i := range opts { opts[i](optionState) } + if accessManager == nil { + log.Warn("access manager was nil, creating a default manager") + accessManager = sdkaccess.NewManager() + configaccess.Register(&cfg.SDKConfig) + accessManager.SetProviders(sdkaccess.RegisteredProviders()) + } + // Set gin mode if !cfg.Debug { gin.SetMode(gin.ReleaseMode) @@ -271,6 +290,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk if optionState.postAuthHook != nil { s.mgmt.SetPostAuthHook(optionState.postAuthHook) } + s.mgmt.SetRuntimeApplier(func(nextCfg *config.Config) error { + s.UpdateClients(nextCfg) + return nil + }) s.localPassword = optionState.localPassword // Setup routes @@ -363,7 +386,7 @@ func (s *Server) setupRoutes() { // OAuth callback endpoints (reuse main server port) // These endpoints receive provider redirects and persist // the short-lived code/state for the waiting goroutine. - s.engine.GET("/anthropic/callback", func(c *gin.Context) { + renderOAuthCallback := func(c *gin.Context, provider string) { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") @@ -371,66 +394,34 @@ func (s *Server) setupRoutes() { errStr = c.Query("error_description") } if state != "" { - _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr) + if err := writePendingOAuthCallbackFile(provider, state, code, errStr); err != nil { + log.Errorf("persist %s oauth callback failed: %v", provider, err) + c.Header("Content-Type", "text/html; charset=utf-8") + c.String(http.StatusInternalServerError, oauthCallbackFailureHTML) + return + } } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) + } + s.engine.GET("/anthropic/callback", func(c *gin.Context) { + renderOAuthCallback(c, "anthropic") }) s.engine.GET("/codex/callback", func(c *gin.Context) { - code := c.Query("code") - state := c.Query("state") - errStr := c.Query("error") - if errStr == "" { - errStr = c.Query("error_description") - } - if state != "" { - _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr) - } - c.Header("Content-Type", "text/html; charset=utf-8") - c.String(http.StatusOK, oauthCallbackSuccessHTML) + renderOAuthCallback(c, "codex") }) s.engine.GET("/google/callback", func(c *gin.Context) { - code := c.Query("code") - state := c.Query("state") - errStr := c.Query("error") - if errStr == "" { - errStr = c.Query("error_description") - } - if state != "" { - _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr) - } - c.Header("Content-Type", "text/html; charset=utf-8") - c.String(http.StatusOK, oauthCallbackSuccessHTML) + renderOAuthCallback(c, "gemini") }) s.engine.GET("/iflow/callback", func(c *gin.Context) { - code := c.Query("code") - state := c.Query("state") - errStr := c.Query("error") - if errStr == "" { - errStr = c.Query("error_description") - } - if state != "" { - _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr) - } - c.Header("Content-Type", "text/html; charset=utf-8") - c.String(http.StatusOK, oauthCallbackSuccessHTML) + renderOAuthCallback(c, "iflow") }) s.engine.GET("/antigravity/callback", func(c *gin.Context) { - code := c.Query("code") - state := c.Query("state") - errStr := c.Query("error") - if errStr == "" { - errStr = c.Query("error_description") - } - if state != "" { - _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr) - } - c.Header("Content-Type", "text/html; charset=utf-8") - c.String(http.StatusOK, oauthCallbackSuccessHTML) + renderOAuthCallback(c, "antigravity") }) // Management routes are registered lazily by registerManagementRoutes when a secret is configured. @@ -875,6 +866,9 @@ func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) { // - clients: The new slice of AI service clients // - cfg: The new application configuration func (s *Server) UpdateClients(cfg *config.Config) { + s.updateClientsMu.Lock() + defer s.updateClientsMu.Unlock() + // Reconstruct old config from YAML snapshot to avoid reference sharing issues var oldCfg *config.Config if len(s.oldConfigYaml) > 0 { @@ -1028,7 +1022,8 @@ func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) { func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc { return func(c *gin.Context) { if manager == nil { - c.Next() + log.Error("authentication middleware misconfigured: access manager is nil") + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "authentication unavailable"}) return } diff --git a/internal/api/server_test.go b/internal/api/server_test.go index f5c18aa167..546738bacc 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -10,6 +10,7 @@ import ( "time" gin "github.com/gin-gonic/gin" + managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" @@ -46,6 +47,53 @@ func newTestServer(t *testing.T) *Server { return NewServer(cfg, authManager, accessManager, configPath) } +func TestAuthMiddleware_NilManagerRejectsRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(AuthMiddleware(nil)) + r.GET("/protected", func(c *gin.Context) { + c.Status(http.StatusNoContent) + }) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("unexpected status: got %d body=%s", w.Code, w.Body.String()) + } +} + +func TestNewServer_CreatesAccessManagerWhenNil(t *testing.T) { + gin.SetMode(gin.TestMode) + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o700); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + + cfg := &proxyconfig.Config{ + SDKConfig: sdkconfig.SDKConfig{APIKeys: []string{"test-key"}}, + Port: 0, + AuthDir: authDir, + } + + server := NewServer(cfg, auth.NewManager(nil, nil, nil), nil, filepath.Join(tmpDir, "config.yaml")) + + unauthorized := httptest.NewRecorder() + server.engine.ServeHTTP(unauthorized, httptest.NewRequest(http.MethodGet, "/v1/models", nil)) + if unauthorized.Code != http.StatusUnauthorized { + t.Fatalf("expected unauthorized without credentials, got %d body=%s", unauthorized.Code, unauthorized.Body.String()) + } + + authorizedReq := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + authorizedReq.Header.Set("Authorization", "Bearer test-key") + authorized := httptest.NewRecorder() + server.engine.ServeHTTP(authorized, authorizedReq) + if authorized.Code != http.StatusOK { + t.Fatalf("expected authorized request to succeed, got %d body=%s", authorized.Code, authorized.Body.String()) + } +} func TestAmpProviderModelRoutes(t *testing.T) { testCases := []struct { name string @@ -208,3 +256,204 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) { } } } + +func TestOAuthCallbackRoute_UsesSessionAuthDir(t *testing.T) { + server := newTestServer(t) + + sessionAuthDir := filepath.Join(t.TempDir(), "session-auth") + if err := os.MkdirAll(sessionAuthDir, 0o700); err != nil { + t.Fatalf("failed to create session auth dir: %v", err) + } + + state := "codex-route-session-state" + managementHandlers.RegisterOAuthSession(state, "codex", sessionAuthDir) + + req := httptest.NewRequest(http.MethodGet, "/codex/callback?state="+state+"&code=route-code", nil) + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rr.Code, rr.Body.String()) + } + + sessionPath := filepath.Join(sessionAuthDir, ".oauth-codex-"+state+".oauth") + configPath := filepath.Join(server.cfg.AuthDir, ".oauth-codex-"+state+".oauth") + if _, err := os.Stat(sessionPath); err != nil { + t.Fatalf("expected callback file in session auth dir, stat err: %v", err) + } + if _, err := os.Stat(configPath); !os.IsNotExist(err) { + t.Fatalf("expected no callback file in config auth dir, stat err: %v", err) + } +} + +func TestOAuthCallbackRoute_ReturnsErrorWhenCallbackFileWriteFails(t *testing.T) { + server := newTestServer(t) + + blockedAuthDir := filepath.Join(t.TempDir(), "blocked-auth-dir") + if err := os.WriteFile(blockedAuthDir, []byte("blocked"), 0o600); err != nil { + t.Fatalf("failed to create blocking auth path: %v", err) + } + + state := "codex-route-write-failure" + managementHandlers.RegisterOAuthSession(state, "codex", blockedAuthDir) + + req := httptest.NewRequest(http.MethodGet, "/codex/callback?state="+state+"&code=route-code", nil) + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Fatalf("expected status %d, got %d with body %s", http.StatusInternalServerError, rr.Code, rr.Body.String()) + } + if strings.Contains(rr.Body.String(), "Authentication successful") { + t.Fatalf("expected failure response, got %s", rr.Body.String()) + } + if !strings.Contains(rr.Body.String(), "Authentication failed") { + t.Fatalf("expected failure response body, got %s", rr.Body.String()) + } + + sessionPath := filepath.Join(blockedAuthDir, ".oauth-codex-"+state+".oauth") + if _, err := os.Stat(sessionPath); !os.IsNotExist(err) { + t.Fatalf("expected no callback file when write fails, stat err: %v", err) + } + if managementHandlers.IsOAuthSessionPending(state, "codex") { + t.Fatal("expected callback write failure to end pending oauth session") + } + provider, status, ok := managementHandlers.GetOAuthSession(state) + if !ok { + t.Fatal("expected oauth session to remain available with error status") + } + if provider != "codex" { + t.Fatalf("expected provider codex, got %q", provider) + } + if strings.TrimSpace(status) == "" { + t.Fatal("expected oauth session error status to be recorded") + } +} + +func TestManagementPutWebsocketAuth_AppliesCommittedConfigAndUpdatesRoutes(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("MANAGEMENT_PASSWORD", "local-secret") + + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o700); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("api-keys:\n - test-key\nws-auth: false\n"), 0o600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + cfg := &proxyconfig.Config{ + SDKConfig: sdkconfig.SDKConfig{APIKeys: []string{"test-key"}}, + AuthDir: authDir, + } + server := NewServer(cfg, auth.NewManager(nil, nil, nil), nil, configPath, WithLocalManagementPassword("local-secret")) + server.AttachWebsocketRoute("/live-ws", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + + beforeReq := httptest.NewRequest(http.MethodGet, "/live-ws", nil) + beforeReq.RemoteAddr = "127.0.0.1:12345" + beforeResp := httptest.NewRecorder() + server.engine.ServeHTTP(beforeResp, beforeReq) + if beforeResp.Code != http.StatusNoContent { + t.Fatalf("expected websocket route to be open before ws-auth update, got %d body=%s", beforeResp.Code, beforeResp.Body.String()) + } + + mgmtReq := httptest.NewRequest(http.MethodPut, "/v0/management/ws-auth", strings.NewReader(`{"value":true}`)) + mgmtReq.RemoteAddr = "127.0.0.1:12345" + mgmtReq.Header.Set("Authorization", "Bearer local-secret") + mgmtReq.Header.Set("Content-Type", "application/json") + mgmtResp := httptest.NewRecorder() + server.engine.ServeHTTP(mgmtResp, mgmtReq) + if mgmtResp.Code != http.StatusOK { + t.Fatalf("expected management update to succeed, got %d body=%s", mgmtResp.Code, mgmtResp.Body.String()) + } + + if server.cfg == nil || !server.cfg.WebsocketAuth { + t.Fatalf("expected server config to reflect committed ws-auth update, got %+v", server.cfg) + } + + persisted, err := proxyconfig.LoadConfig(configPath) + if err != nil { + t.Fatalf("failed to load persisted config: %v", err) + } + if !persisted.WebsocketAuth { + t.Fatalf("expected persisted config to enable ws-auth, got %+v", persisted) + } + + afterReq := httptest.NewRequest(http.MethodGet, "/live-ws", nil) + afterReq.RemoteAddr = "127.0.0.1:12345" + afterResp := httptest.NewRecorder() + server.engine.ServeHTTP(afterResp, afterReq) + if afterResp.Code != http.StatusUnauthorized { + t.Fatalf("expected websocket route to require auth after ws-auth update, got %d body=%s", afterResp.Code, afterResp.Body.String()) + } + + authorizedReq := httptest.NewRequest(http.MethodGet, "/live-ws", nil) + authorizedReq.RemoteAddr = "127.0.0.1:12345" + authorizedReq.Header.Set("Authorization", "Bearer test-key") + authorizedResp := httptest.NewRecorder() + server.engine.ServeHTTP(authorizedResp, authorizedReq) + if authorizedResp.Code != http.StatusNoContent { + t.Fatalf("expected authorized websocket route to succeed after ws-auth update, got %d body=%s", authorizedResp.Code, authorizedResp.Body.String()) + } +} + +func TestManagementPutConfigYAML_AppliesCommittedConfigAndEnablesRequestLogging(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("MANAGEMENT_PASSWORD", "local-secret") + + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o700); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("api-keys:\n - test-key\nrequest-log: false\n"), 0o600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + cfg := &proxyconfig.Config{ + SDKConfig: sdkconfig.SDKConfig{ + APIKeys: []string{"test-key"}, + RequestLog: false, + }, + AuthDir: authDir, + } + server := NewServer(cfg, auth.NewManager(nil, nil, nil), nil, configPath, WithLocalManagementPassword("local-secret")) + if server.requestLogger == nil { + t.Fatal("expected request logger to be configured") + } + if server.requestLogger.IsEnabled() { + t.Fatal("expected request logger to start disabled") + } + + mgmtReq := httptest.NewRequest(http.MethodPut, "/v0/management/config.yaml", strings.NewReader("api-keys:\n - test-key\nrequest-log: true\n")) + mgmtReq.RemoteAddr = "127.0.0.1:12345" + mgmtReq.Header.Set("Authorization", "Bearer local-secret") + mgmtReq.Header.Set("Content-Type", "application/yaml") + mgmtResp := httptest.NewRecorder() + server.engine.ServeHTTP(mgmtResp, mgmtReq) + if mgmtResp.Code != http.StatusOK { + t.Fatalf("expected config.yaml update to succeed, got %d body=%s", mgmtResp.Code, mgmtResp.Body.String()) + } + + if server.cfg == nil || !server.cfg.RequestLog { + t.Fatalf("expected server config to reflect committed request-log update, got %+v", server.cfg) + } + if !server.requestLogger.IsEnabled() { + t.Fatal("expected request logger to be enabled after config.yaml update") + } + + persisted, err := proxyconfig.LoadConfig(configPath) + if err != nil { + t.Fatalf("failed to load persisted config: %v", err) + } + if !persisted.RequestLog { + t.Fatalf("expected persisted config to enable request-log, got %+v", persisted) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index a11c741efc..0d1d4ea717 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,6 +21,7 @@ import ( const ( DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center" DefaultPprofAddr = "127.0.0.1:8316" + MaxLogsMaxTotalSizeMB = 1024 ) // Config represents the application's configuration, loaded from a YAML file. @@ -608,6 +609,8 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { if cfg.LogsMaxTotalSizeMB < 0 { cfg.LogsMaxTotalSizeMB = 0 + } else if cfg.LogsMaxTotalSizeMB > MaxLogsMaxTotalSizeMB { + cfg.LogsMaxTotalSizeMB = MaxLogsMaxTotalSizeMB } if cfg.ErrorLogsMaxFiles < 0 { @@ -993,6 +996,7 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error { pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-excluded-models") pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-model-alias") + pruneAmpCodeGeneratedKey(original.Content[0], generated.Content[0], "upstream-api-keys") // Merge generated into original in-place, preserving comments/order of existing nodes. mergeMappingPreserve(original.Content[0], generated.Content[0]) @@ -1635,6 +1639,29 @@ func pruneMissingMapKeys(dstMap, srcMap *yaml.Node) { } } +func pruneAmpCodeGeneratedKey(dstRoot, srcRoot *yaml.Node, key string) { + if key == "" || dstRoot == nil || srcRoot == nil || dstRoot.Kind != yaml.MappingNode || srcRoot.Kind != yaml.MappingNode { + return + } + dstIdx := findMapKeyIndex(dstRoot, "ampcode") + if dstIdx < 0 || dstIdx+1 >= len(dstRoot.Content) { + return + } + dstAmp := dstRoot.Content[dstIdx+1] + if dstAmp == nil || dstAmp.Kind != yaml.MappingNode { + return + } + srcIdx := findMapKeyIndex(srcRoot, "ampcode") + if srcIdx < 0 || srcIdx+1 >= len(srcRoot.Content) { + return + } + srcAmp := srcRoot.Content[srcIdx+1] + if srcAmp == nil || srcAmp.Kind != yaml.MappingNode { + return + } + pruneMappingToGeneratedKeys(dstAmp, srcAmp, key) +} + // normalizeCollectionNodeStyles forces YAML collections to use block notation, keeping // lists and maps readable. Empty sequences retain flow style ([]) so empty list markers // remain compact. diff --git a/internal/config/config_limits_test.go b/internal/config/config_limits_test.go new file mode 100644 index 0000000000..0a027b346b --- /dev/null +++ b/internal/config/config_limits_test.go @@ -0,0 +1,53 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoadConfig_ClampsLogsMaxTotalSizeMB(t *testing.T) { + path := filepath.Join(t.TempDir(), "config.yaml") + if err := os.WriteFile(path, []byte("logs-max-total-size-mb: 1048577\n"), 0o600); err != nil { + t.Fatalf("write temp config: %v", err) + } + + cfg, err := LoadConfig(path) + if err != nil { + t.Fatalf("LoadConfig returned error: %v", err) + } + if cfg.LogsMaxTotalSizeMB != MaxLogsMaxTotalSizeMB { + t.Fatalf("expected logs-max-total-size-mb to be clamped to %d, got %d", MaxLogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB) + } +} + +func TestSaveConfigPreserveComments_PrunesRemovedAmpUpstreamAPIKeysWithoutDroppingUnknownAmpKeys(t *testing.T) { + path := filepath.Join(t.TempDir(), "config.yaml") + original := []byte("ampcode:\n upstream-url: https://example.com\n upstream-api-keys:\n - upstream-api-key: old\n api-keys:\n - key\n custom-extra: keep-me\n") + if err := os.WriteFile(path, original, 0o600); err != nil { + t.Fatalf("write temp config: %v", err) + } + + cfg := &Config{ + AmpCode: AmpCode{ + UpstreamURL: "https://example.com", + UpstreamAPIKeys: nil, + }, + } + if err := SaveConfigPreserveComments(path, cfg); err != nil { + t.Fatalf("SaveConfigPreserveComments returned error: %v", err) + } + + persisted, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read persisted config: %v", err) + } + persistedText := string(persisted) + if strings.Contains(persistedText, "upstream-api-keys") { + t.Fatalf("expected upstream-api-keys to be pruned, got %s", persistedText) + } + if !strings.Contains(persistedText, "custom-extra: keep-me") { + t.Fatalf("expected unknown ampcode key to be preserved, got %s", persistedText) + } +} diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 82b12a2f80..c7db451647 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -391,6 +391,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A log.Errorf("response body close error: %v", errClose) } }() + // Ensure every stream path, including Claude passthrough, records at least one usage entry. + defer reporter.ensurePublished(ctx) // If from == to (Claude → Claude), directly forward the SSE stream without translation if from == to { diff --git a/internal/runtime/executor/claude_executor_usage_test.go b/internal/runtime/executor/claude_executor_usage_test.go new file mode 100644 index 0000000000..2a7971b045 --- /dev/null +++ b/internal/runtime/executor/claude_executor_usage_test.go @@ -0,0 +1,105 @@ +package executor + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" +) + +type authScopedUsagePlugin struct { + authID string + records chan usage.Record +} + +func (p *authScopedUsagePlugin) HandleUsage(_ context.Context, record usage.Record) { + if p == nil || record.AuthID != p.authID { + return + } + select { + case p.records <- record: + default: + } +} + +var ( + claudePassthroughUsagePluginOnce sync.Once + claudePassthroughUsagePlugin = &authScopedUsagePlugin{ + authID: "claude-passthrough-no-usage", + records: make(chan usage.Record, 8), + } +) + +func waitForUsageRecord(t *testing.T, records <-chan usage.Record) usage.Record { + t.Helper() + select { + case record := <-records: + return record + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for usage record") + return usage.Record{} + } +} + +func TestClaudeExecutorExecuteStream_PassthroughPublishesFallbackUsageWithoutUsageChunk(t *testing.T) { + claudePassthroughUsagePluginOnce.Do(func() { + usage.RegisterPlugin(claudePassthroughUsagePlugin) + }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\"}}\n\n")) + _, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hi\"}}\n\n")) + _, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n")) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + ID: "claude-passthrough-no-usage", + Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }, + } + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + Stream: true, + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream chunk error: %v", chunk.Err) + } + } + + record := waitForUsageRecord(t, claudePassthroughUsagePlugin.records) + if record.AuthID != auth.ID { + t.Fatalf("usage record auth_id = %q, want %q", record.AuthID, auth.ID) + } + if record.Provider != "claude" { + t.Fatalf("usage record provider = %q, want %q", record.Provider, "claude") + } + if record.Failed { + t.Fatal("usage fallback should mark request as successful") + } + if record.Detail != (usage.Detail{}) { + t.Fatalf("usage fallback detail = %+v, want zero-value detail", record.Detail) + } +} diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 7c25b8935f..8d5ebfbe80 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -333,6 +333,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: errScan} } + // Ensure we record the request if no usage chunk was ever seen + reporter.ensurePublished(ctx) }() return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } diff --git a/internal/runtime/executor/kimi_executor.go b/internal/runtime/executor/kimi_executor.go index d5e3702f48..843f941dfe 100644 --- a/internal/runtime/executor/kimi_executor.go +++ b/internal/runtime/executor/kimi_executor.go @@ -283,6 +283,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: errScan} } + // Ensure we record the request if no usage chunk was ever seen + reporter.ensurePublished(ctx) }() return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 623c66206a..5e53d937ce 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -17,9 +17,12 @@ import ( cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) +const openAICompatRetryErrorBodyLimit = 1 << 20 + // OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers. // It performs request/response translation and executes against the provider base URL // using per-auth credentials (API key) and per-auth HTTP transport (proxy) from context. @@ -199,15 +202,22 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) requestedModel := payloadRequestedModel(opts, req.Model) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) + // Preserve historical behavior: if include_usage is omitted or explicitly + // sent as false/null, still force it on so upstreams can emit real usage + // chunks. Only an explicit true counts as caller-enabled. + autoInjectedStreamUsage := !gjson.GetBytes(translated, "stream_options.include_usage").Bool() translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, err } - // Request usage data in the final streaming chunk so that token statistics - // are captured even when the upstream is an OpenAI-compatible provider. - translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true) + if autoInjectedStreamUsage { + translated, err = sjson.SetBytes(translated, "stream_options.include_usage", true) + if err != nil { + return nil, fmt.Errorf("openai compat executor: failed to set stream_options in payload: %w", err) + } + } url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated)) @@ -250,6 +260,13 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy recordAPIResponseError(ctx, e.cfg, err) return nil, err } + if retryResp, retryErr := e.retryStreamWithoutInjectedUsage(ctx, auth, httpClient, httpReq, translated, httpResp, autoInjectedStreamUsage); retryResp != nil || retryErr != nil { + httpResp = retryResp + if retryErr != nil { + recordAPIResponseError(ctx, e.cfg, retryErr) + return nil, retryErr + } + } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) @@ -304,6 +321,95 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } +func (e *OpenAICompatExecutor) retryStreamWithoutInjectedUsage(ctx context.Context, auth *cliproxyauth.Auth, httpClient *http.Client, httpReq *http.Request, translated []byte, httpResp *http.Response, autoInjected bool) (*http.Response, error) { + if !autoInjected || httpResp == nil { + return nil, nil + } + if httpResp.StatusCode != http.StatusBadRequest && httpResp.StatusCode != http.StatusUnprocessableEntity { + return nil, nil + } + body, err := io.ReadAll(io.LimitReader(httpResp.Body, openAICompatRetryErrorBodyLimit+1)) + if err != nil { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Warnf("openai compat executor: failed to close body after read error: %v", errClose) + } + return nil, err + } + if errClose := httpResp.Body.Close(); errClose != nil { + log.Warnf("openai compat executor: close fallback response body error: %v", errClose) + } + if len(body) > openAICompatRetryErrorBodyLimit { + log.Warnf("openai compat executor: fallback response body exceeded %d bytes; skip retry without include_usage", openAICompatRetryErrorBodyLimit) + httpResp.Body = io.NopCloser(bytes.NewReader(body[:openAICompatRetryErrorBodyLimit])) + return httpResp, nil + } + if !isUnsupportedInjectedUsageError(body) { + httpResp.Body = io.NopCloser(bytes.NewReader(body)) + return httpResp, nil + } + trimmed, err := sjson.DeleteBytes(translated, "stream_options.include_usage") + if err != nil { + return nil, fmt.Errorf("openai compat executor: failed to remove unsupported stream_options in payload: %w", err) + } + if streamOptions := gjson.GetBytes(trimmed, "stream_options"); streamOptions.Exists() && len(streamOptions.Map()) == 0 { + trimmed, err = sjson.DeleteBytes(trimmed, "stream_options") + if err != nil { + return nil, fmt.Errorf("openai compat executor: failed to remove empty stream_options in payload: %w", err) + } + } + retryReq, err := http.NewRequestWithContext(ctx, httpReq.Method, httpReq.URL.String(), bytes.NewReader(trimmed)) + if err != nil { + return nil, err + } + retryReq.Header = httpReq.Header.Clone() + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: retryReq.URL.String(), + Method: retryReq.Method, + Headers: retryReq.Header.Clone(), + Body: trimmed, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + return httpClient.Do(retryReq) +} + +func isUnsupportedInjectedUsageError(body []byte) bool { + if len(body) == 0 { + return false + } + lower := strings.ToLower(string(body)) + if !strings.Contains(lower, "stream_options") && !strings.Contains(lower, "include_usage") { + return false + } + unsupportedMarkers := []string{ + "unknown field", + "unknown parameter", + "unknown argument", + "unrecognized field", + "unrecognized parameter", + "unsupported field", + "unsupported parameter", + "not allowed", + "not permitted", + "extra inputs are not permitted", + } + for _, marker := range unsupportedMarkers { + if strings.Contains(lower, marker) { + return true + } + } + return false +} + func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { baseModel := thinking.ParseSuffix(req.Model).ModelName diff --git a/internal/runtime/executor/openai_compat_executor_stream_test.go b/internal/runtime/executor/openai_compat_executor_stream_test.go new file mode 100644 index 0000000000..70ee90ae95 --- /dev/null +++ b/internal/runtime/executor/openai_compat_executor_stream_test.go @@ -0,0 +1,242 @@ +package executor + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/gjson" +) + +func TestOpenAICompatExecutorExecuteStreamSetsIncludeUsage(t *testing.T) { + var gotBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + gotBody, err = io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(`data: {"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"hi"}}]} + +`)) + _, _ = w.Write([]byte(`data: [DONE] + +`)) + })) + defer server.Close() + + executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL + "/v1", + "api_key": "test", + }} + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "gpt-4o-mini", + Payload: []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: true, + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream chunk error: %v", chunk.Err) + } + } + if !gjson.GetBytes(gotBody, "stream_options.include_usage").Bool() { + t.Fatalf("expected stream_options.include_usage=true, got body: %s", string(gotBody)) + } +} + +func TestOpenAICompatExecutorExecuteStreamRetriesWithoutInjectedIncludeUsage(t *testing.T) { + var gotBodies [][]byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + gotBodies = append(gotBodies, body) + if len(gotBodies) == 1 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"unknown field stream_options.include_usage"}}`)) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(`data: {"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"hi"}}]} + +`)) + _, _ = w.Write([]byte(`data: [DONE] + +`)) + })) + defer server.Close() + + executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL + "/v1", + "api_key": "test", + }} + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "gpt-4o-mini", + Payload: []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: true, + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream chunk error: %v", chunk.Err) + } + } + if len(gotBodies) != 2 { + t.Fatalf("expected 2 upstream requests, got %d", len(gotBodies)) + } + if !gjson.GetBytes(gotBodies[0], "stream_options.include_usage").Bool() { + t.Fatalf("expected first request to include injected usage flag, got body: %s", string(gotBodies[0])) + } + if gjson.GetBytes(gotBodies[1], "stream_options.include_usage").Exists() { + t.Fatalf("expected retry request to remove include_usage, got body: %s", string(gotBodies[1])) + } + if gjson.GetBytes(gotBodies[1], "stream_options").Exists() { + t.Fatalf("expected retry request to remove empty stream_options, got body: %s", string(gotBodies[1])) + } +} + +func TestOpenAICompatExecutorExecuteStreamDoesNotRetryUnrelatedValidationErrors(t *testing.T) { + var gotBodies [][]byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + gotBodies = append(gotBodies, body) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"error":{"message":"validation failed for messages[0].content; request trace mentions stream_options.include_usage but the actual issue is content shape"}}`)) + })) + defer server.Close() + + executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL + "/v1", + "api_key": "test", + }} + _, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "gpt-4o-mini", + Payload: []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: true, + }) + if err == nil { + t.Fatal("expected ExecuteStream to return validation error") + } + if len(gotBodies) != 1 { + t.Fatalf("expected 1 upstream request for unrelated validation error, got %d", len(gotBodies)) + } + if !gjson.GetBytes(gotBodies[0], "stream_options.include_usage").Bool() { + t.Fatalf("expected original request to include injected usage flag, got body: %s", string(gotBodies[0])) + } + if statusProvider, ok := err.(interface{ StatusCode() int }); !ok || statusProvider.StatusCode() != http.StatusUnprocessableEntity { + t.Fatalf("expected status code 422, got: %v", err) + } +} + +func TestOpenAICompatExecutorExecuteStreamForcesIncludeUsageWhenCallerSendsFalse(t *testing.T) { + var gotBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + gotBody, err = io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(`data: {"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"hi"}}]} + +`)) + _, _ = w.Write([]byte(`data: [DONE] + +`)) + })) + defer server.Close() + + executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL + "/v1", + "api_key": "test", + }} + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "gpt-4o-mini", + Payload: []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}],"stream_options":{"include_usage":false}}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: true, + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream chunk error: %v", chunk.Err) + } + } + if !gjson.GetBytes(gotBody, "stream_options.include_usage").Bool() { + t.Fatalf("expected include_usage to be forced to true, got body: %s", string(gotBody)) + } +} + +func TestOpenAICompatExecutorExecuteStreamForcesIncludeUsageWhenCallerSendsNull(t *testing.T) { + var gotBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + gotBody, err = io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(`data: {"id":"chatcmpl-1","choices":[{"index":0,"delta":{"content":"hi"}}]} + +`)) + _, _ = w.Write([]byte(`data: [DONE] + +`)) + })) + defer server.Close() + + executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL + "/v1", + "api_key": "test", + }} + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "gpt-4o-mini", + Payload: []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}],"stream_options":{"include_usage":null}}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: true, + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream chunk error: %v", chunk.Err) + } + } + if !gjson.GetBytes(gotBody, "stream_options.include_usage").Bool() { + t.Fatalf("expected include_usage to be forced to true, got body: %s", string(gotBody)) + } +} diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index e7957d2918..7ff4a4e85c 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -433,6 +433,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: errScan} } + // Ensure we record the request if no usage chunk was ever seen + reporter.ensurePublished(ctx) }() return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } diff --git a/internal/runtime/executor/qwen_executor_test.go b/internal/runtime/executor/qwen_executor_test.go index 6a777c53c5..e69486f419 100644 --- a/internal/runtime/executor/qwen_executor_test.go +++ b/internal/runtime/executor/qwen_executor_test.go @@ -1,9 +1,26 @@ package executor import ( + "context" + "net/http" + "net/http/httptest" + "sync" "testing" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" +) + +var ( + qwenPassthroughUsagePluginOnce sync.Once + qwenPassthroughUsagePlugin = &authScopedUsagePlugin{ + authID: "qwen-passthrough-no-usage", + records: make(chan usage.Record, 8), + } ) func TestQwenExecutorParseSuffix(t *testing.T) { @@ -28,3 +45,56 @@ func TestQwenExecutorParseSuffix(t *testing.T) { }) } } + +func TestQwenExecutorExecuteStream_PublishesFallbackUsageWithoutUsageChunk(t *testing.T) { + qwenPassthroughUsagePluginOnce.Do(func() { + usage.RegisterPlugin(qwenPassthroughUsagePlugin) + }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + })) + defer server.Close() + + executor := NewQwenExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + ID: "qwen-passthrough-no-usage", + Attributes: map[string]string{ + "api_key": "token-123", + "base_url": server.URL, + }, + } + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "qwen-max", + Payload: []byte(`{"model":"qwen-max","messages":[{"role":"user","content":"hi"}]}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: true, + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream chunk error: %v", chunk.Err) + } + } + + record := waitForUsageRecord(t, qwenPassthroughUsagePlugin.records) + if record.AuthID != auth.ID { + t.Fatalf("usage record auth_id = %q, want %q", record.AuthID, auth.ID) + } + if record.Provider != "qwen" { + t.Fatalf("usage record provider = %q, want %q", record.Provider, "qwen") + } + if record.Failed { + t.Fatal("usage fallback should mark request as successful") + } + if record.Detail != (usage.Detail{}) { + t.Fatalf("usage fallback detail = %+v, want zero-value detail", record.Detail) + } +} diff --git a/internal/runtime/executor/usage_helpers.go b/internal/runtime/executor/usage_helpers.go index 00f547df22..4212eecfe1 100644 --- a/internal/runtime/executor/usage_helpers.go +++ b/internal/runtime/executor/usage_helpers.go @@ -185,12 +185,8 @@ func parseCodexUsage(data []byte) (usage.Detail, bool) { OutputTokens: usageNode.Get("output_tokens").Int(), TotalTokens: usageNode.Get("total_tokens").Int(), } - if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() { - detail.CachedTokens = cached.Int() - } - if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() { - detail.ReasoningTokens = reasoning.Int() - } + detail.CachedTokens = usageNode.Get("input_tokens_details.cached_tokens").Int() + detail.ReasoningTokens = usageNode.Get("output_tokens_details.reasoning_tokens").Int() return detail, true } @@ -216,16 +212,12 @@ func parseOpenAIUsage(data []byte) usage.Detail { if !cached.Exists() { cached = usageNode.Get("input_tokens_details.cached_tokens") } - if cached.Exists() { - detail.CachedTokens = cached.Int() - } + detail.CachedTokens = cached.Int() reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens") if !reasoning.Exists() { reasoning = usageNode.Get("output_tokens_details.reasoning_tokens") } - if reasoning.Exists() { - detail.ReasoningTokens = reasoning.Int() - } + detail.ReasoningTokens = reasoning.Int() return detail } @@ -238,17 +230,29 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) { if !usageNode.Exists() { return usage.Detail{}, false } + inputNode := usageNode.Get("prompt_tokens") + if !inputNode.Exists() { + inputNode = usageNode.Get("input_tokens") + } + outputNode := usageNode.Get("completion_tokens") + if !outputNode.Exists() { + outputNode = usageNode.Get("output_tokens") + } detail := usage.Detail{ - InputTokens: usageNode.Get("prompt_tokens").Int(), - OutputTokens: usageNode.Get("completion_tokens").Int(), + InputTokens: inputNode.Int(), + OutputTokens: outputNode.Int(), TotalTokens: usageNode.Get("total_tokens").Int(), } - if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() { - detail.CachedTokens = cached.Int() + cached := usageNode.Get("prompt_tokens_details.cached_tokens") + if !cached.Exists() { + cached = usageNode.Get("input_tokens_details.cached_tokens") } - if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() { - detail.ReasoningTokens = reasoning.Int() + detail.CachedTokens = cached.Int() + reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens") + if !reasoning.Exists() { + reasoning = usageNode.Get("output_tokens_details.reasoning_tokens") } + detail.ReasoningTokens = reasoning.Int() return detail, true } diff --git a/internal/runtime/executor/usage_helpers_test.go b/internal/runtime/executor/usage_helpers_test.go index 337f108af7..0a68f8de3d 100644 --- a/internal/runtime/executor/usage_helpers_test.go +++ b/internal/runtime/executor/usage_helpers_test.go @@ -41,3 +41,26 @@ func TestParseOpenAIUsageResponses(t *testing.T) { t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9) } } + +func TestParseOpenAIStreamUsageResponses(t *testing.T) { + line := []byte(`data: {"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`) + detail, ok := parseOpenAIStreamUsage(line) + if !ok { + t.Fatal("expected stream usage to be parsed") + } + if detail.InputTokens != 10 { + t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10) + } + if detail.OutputTokens != 20 { + t.Fatalf("output tokens = %d, want %d", detail.OutputTokens, 20) + } + if detail.TotalTokens != 30 { + t.Fatalf("total tokens = %d, want %d", detail.TotalTokens, 30) + } + if detail.CachedTokens != 7 { + t.Fatalf("cached tokens = %d, want %d", detail.CachedTokens, 7) + } + if detail.ReasoningTokens != 9 { + t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9) + } +} diff --git a/test/amp_management_test.go b/test/amp_management_test.go index e384ef0e8b..88dd2d7279 100644 --- a/test/amp_management_test.go +++ b/test/amp_management_test.go @@ -404,8 +404,8 @@ func TestDeleteAmpModelMappings_Specific(t *testing.T) { } } -// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings. -func TestDeleteAmpModelMappings_All(t *testing.T) { +// TestDeleteAmpModelMappings_EmptyBody verifies DELETE with empty body is rejected and does not clear mappings. +func TestDeleteAmpModelMappings_EmptyBody(t *testing.T) { h, _ := newAmpTestHandler(t) r := setupAmpRouter(h) @@ -413,9 +413,129 @@ func TestDeleteAmpModelMappings_All(t *testing.T) { w := httptest.NewRecorder() r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if len(resp["model-mappings"]) != 1 { + t.Fatalf("expected original mappings to remain, got %d", len(resp["model-mappings"])) + } +} + +// TestDeleteAmpModelMappings_InvalidJSON verifies invalid JSON is rejected and does not clear mappings. +func TestDeleteAmpModelMappings_InvalidJSON(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString("{bad json")) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if len(resp["model-mappings"]) != 1 { + t.Fatalf("expected original mappings to remain, got %d", len(resp["model-mappings"])) + } +} + +// TestDeleteAmpModelMappings_MissingValue verifies missing value is rejected and does not clear mappings. +func TestDeleteAmpModelMappings_MissingValue(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if len(resp["model-mappings"]) != 1 { + t.Fatalf("expected original mappings to remain, got %d", len(resp["model-mappings"])) + } +} + +// TestDeleteAmpModelMappings_NullValue verifies null value is rejected and does not clear mappings. +func TestDeleteAmpModelMappings_NullValue(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(`{"value":null}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if len(resp["model-mappings"]) != 1 { + t.Fatalf("expected original mappings to remain, got %d", len(resp["model-mappings"])) + } +} + +// TestDeleteAmpModelMappings_EmptyArrayClearsAll verifies an explicit empty array clears all mappings. +func TestDeleteAmpModelMappings_EmptyArrayClearsAll(t *testing.T) { + h, _ := newAmpTestHandler(t) + r := setupAmpRouter(h) + + req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(`{"value":[]}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) } + + req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil) + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + var resp map[string][]config.AmpModelMapping + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if len(resp["model-mappings"]) != 0 { + t.Fatalf("expected mappings to be cleared, got %d", len(resp["model-mappings"])) + } } // TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting.