diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 0e490e3202..dc89a7360c 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -208,6 +208,11 @@ func requestExecutionMetadata(ctx context.Context) map[string]any { if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" { meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID } + // Sticky session: forward X-CLIProxyAPI-Session-ID header so the conductor can pin + // subsequent requests from the same session to the same auth account. + if sessionID := stickySessionIDFromHeader(ctx); sessionID != "" { + meta[coreexecutor.StickySessionMetadataKey] = sessionID + } return meta } @@ -252,6 +257,21 @@ func executionSessionIDFromContext(ctx context.Context) string { } } +func stickySessionIDFromHeader(ctx context.Context) string { + if ctx == nil { + return "" + } + ginCtx, ok := ctx.Value("gin").(*gin.Context) + if !ok || ginCtx == nil || ginCtx.Request == nil { + return "" + } + id := strings.TrimSpace(ginCtx.GetHeader("X-CLIProxyAPI-Session-ID")) + if len(id) > coreauth.StickyMaxSessionIDLen { + return "" + } + return id +} + // BaseAPIHandler contains the handlers for API endpoints. // It holds a pool of clients to interact with the backend service and manages // load balancing, client selection, and configuration. diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index b29e04db8c..a2b70b4b45 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -160,6 +160,9 @@ type Manager struct { // Optional HTTP RoundTripper provider injected by host. rtProvider RoundTripperProvider + // sticky maintains session-to-auth bindings for sticky routing. 
+ sticky *stickyStore + // Auto refresh state refreshCancel context.CancelFunc refreshSemaphore chan struct{} @@ -181,6 +184,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { auths: make(map[string]*Auth), providerOffsets: make(map[string]int), modelPoolOffsets: make(map[string]int), + sticky: newStickyStore(), refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), } // atomic.Value requires non-nil initial value. @@ -483,7 +487,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC } } -func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult { +func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk, stickyKey string) *cliproxyexecutor.StreamResult { out := make(chan cliproxyexecutor.StreamChunk) go func() { defer close(out) @@ -497,6 +501,12 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro rerr.HTTPStatus = se.StatusCode() } m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}) + // Sticky session: clear binding on mid-stream error. 
+ if stickyKey != "" { + if sc := statusCodeFromResult(rerr); sc == 429 || sc >= 500 { + m.sticky.Delete(stickyKey) + } + } } if !forward { return false @@ -532,9 +542,9 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out} } -func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) { +func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, stickyKey string) (*cliproxyexecutor.StreamResult, bool, error) { if executor == nil { - return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} + return nil, false, &Error{Code: "executor_not_found", Message: "executor not registered"} } execModels := m.prepareExecutionModels(auth, routeModel) var lastErr error @@ -544,7 +554,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts) if errStream != nil { if errCtx := ctx.Err(); errCtx != nil { - return nil, errCtx + return nil, false, errCtx } rerr := &Error{Message: errStream.Error()} if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { @@ -554,7 +564,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi result.RetryAfter = retryAfterFromError(errStream) m.MarkResult(ctx, result) if isRequestInvalidError(errStream) { - return nil, errStream + return nil, false, errStream } lastErr = errStream continue @@ -564,7 +574,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi if bootstrapErr != nil { if errCtx := ctx.Err(); errCtx != nil 
{ discardStreamChunks(streamResult.Chunks) - return nil, errCtx + return nil, false, errCtx } if isRequestInvalidError(bootstrapErr) { rerr := &Error{Message: bootstrapErr.Error()} @@ -575,7 +585,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi result.RetryAfter = retryAfterFromError(bootstrapErr) m.MarkResult(ctx, result) discardStreamChunks(streamResult.Chunks) - return nil, bootstrapErr + return nil, false, bootstrapErr } if idx < len(execModels)-1 { rerr := &Error{Message: bootstrapErr.Error()} @@ -592,7 +602,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi errCh := make(chan cliproxyexecutor.StreamChunk, 1) errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr} close(errCh) - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickyKey), false, nil } if closed && len(buffered) == 0 { @@ -606,7 +616,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi errCh := make(chan cliproxyexecutor.StreamChunk, 1) errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr} close(errCh) - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickyKey), false, nil } remaining := streamResult.Chunks @@ -615,12 +625,12 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi close(closedCh) remaining = closedCh } - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining, stickyKey), true, nil } if lastErr == nil { lastErr = &Error{Code: "auth_not_found", Message: "no upstream 
model available"} } - return nil, lastErr + return nil, false, lastErr } func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() { @@ -978,6 +988,10 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req } routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) + + stickySessionID := m.resolveStickySession(opts.Metadata, routeModel) + sk := stickyKey(stickySessionID, routeModel) + tried := make(map[string]struct{}) var lastErr error for { @@ -1025,12 +1039,24 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req result.RetryAfter = ra } m.MarkResult(execCtx, result) + // Sticky session (sk is never empty, so test the session ID): on rate-limit or server + // error clear the binding so the next request falls back to normal auth selection. + if stickySessionID != "" { + if sc := statusCodeFromResult(result.Error); sc == 429 || sc >= 500 { + m.sticky.Delete(sk) + delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey) + } + } if isRequestInvalidError(errExec) { return cliproxyexecutor.Response{}, errExec } authErr = errExec continue } + // Sticky session: bind session to the auth that succeeded. 
+ if stickySessionID != "" { + m.sticky.Set(sk, auth.ID, stickySessionTTL) + } m.MarkResult(execCtx, result) return resp, nil } @@ -1122,6 +1148,10 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string } routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) + + stickySessionID := m.resolveStickySession(opts.Metadata, routeModel) + sk := stickyKey(stickySessionID, routeModel) + tried := make(map[string]struct{}) var lastErr error for { @@ -1149,17 +1179,28 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel) + streamResult, streamOK, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, sk) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx } + // Sticky session (sk is never empty, so test the session ID): clear binding on rate-limit or server error. + if stickySessionID != "" { + if sc := statusCodeFromError(errStream); sc == 429 || sc >= 500 { + m.sticky.Delete(sk) + delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey) + } + } if isRequestInvalidError(errStream) { return nil, errStream } lastErr = errStream continue } + // Sticky session: only bind on genuine stream success. + if stickySessionID != "" && streamOK { + m.sticky.Set(sk, auth.ID, stickySessionTTL) + } return streamResult, nil } } @@ -1221,6 +1262,50 @@ func pinnedAuthIDFromMetadata(meta map[string]any) string { } } +// stickySessionTTL is the duration a session-to-auth binding remains valid. +const stickySessionTTL = time.Hour + +// stickyKey builds a composite store key from session ID and model so that +// the same session ID used with different models gets independent bindings. 
+func stickyKey(sessionID, model string) string { + return sessionID + "|" + model +} + +func stickySessionIDFromMetadata(meta map[string]any) string { + if len(meta) == 0 { + return "" + } + raw, ok := meta[cliproxyexecutor.StickySessionMetadataKey] + if !ok || raw == nil { + return "" + } + switch val := raw.(type) { + case string: + return strings.TrimSpace(val) + case []byte: + return strings.TrimSpace(string(val)) + default: + return "" + } +} + +// resolveStickySession extracts the sticky session ID from metadata and, if a +// valid binding exists, sets the pinned auth ID. An explicit pinned_auth_id +// in the metadata takes precedence over any sticky binding. +// routeModel scopes the lookup so the same session ID with different models +// gets independent auth bindings. +func (m *Manager) resolveStickySession(meta map[string]any, routeModel string) string { + id := stickySessionIDFromMetadata(meta) + if id != "" { + if _, alreadyPinned := meta[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned { + if boundAuth, found := m.sticky.Get(stickyKey(id, routeModel)); found { + meta[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth + } + } + } + return id +} + func publishSelectedAuthMetadata(meta map[string]any, authID string) { if len(meta) == 0 { return @@ -2331,6 +2416,8 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio go func() { ticker := time.NewTicker(interval) defer ticker.Stop() + lastStickyCleanup := time.Now() + const stickyCleanupInterval = 5 * time.Minute m.checkRefreshes(ctx) for { select { @@ -2338,6 +2425,10 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio return case <-ticker.C: m.checkRefreshes(ctx) + if time.Since(lastStickyCleanup) >= stickyCleanupInterval { + m.sticky.Cleanup() + lastStickyCleanup = time.Now() + } } } }() diff --git a/sdk/cliproxy/auth/sticky.go b/sdk/cliproxy/auth/sticky.go new file mode 100644 index 0000000000..c0eeede850 --- /dev/null +++ 
b/sdk/cliproxy/auth/sticky.go
@@ -0,0 +1,97 @@
+package auth
+
+import (
+	"sync"
+	"time"
+)
+
+// stickyStore maintains session-to-auth bindings so that requests carrying the
+// same session ID are routed to the same auth/account. Entries expire after a
+// configurable TTL and are garbage-collected by Cleanup, or by Set when the
+// store has reached capacity.
+//
+// maxEntries caps the number of stored bindings to prevent memory exhaustion
+// from untrusted X-CLIProxyAPI-Session-ID headers.
+type stickyStore struct {
+	mu         sync.RWMutex
+	entries    map[string]stickyEntry
+	maxEntries int
+}
+
+// stickyEntry is a single binding: the bound auth ID and its expiry instant.
+type stickyEntry struct {
+	authID    string
+	expiresAt time.Time
+}
+
+// stickyMaxEntries is the upper bound on stored session bindings.
+const stickyMaxEntries = 10_000
+
+// StickyMaxSessionIDLen limits the accepted X-CLIProxyAPI-Session-ID length.
+const StickyMaxSessionIDLen = 256
+
+// newStickyStore returns an empty store capped at stickyMaxEntries bindings.
+func newStickyStore() *stickyStore {
+	return &stickyStore{entries: make(map[string]stickyEntry), maxEntries: stickyMaxEntries}
+}
+
+// Get returns the bound auth ID for the given session, if it exists and has not
+// expired.
+func (s *stickyStore) Get(sessionID string) (string, bool) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	e, ok := s.entries[sessionID]
+	if !ok || time.Now().After(e.expiresAt) {
+		return "", false
+	}
+	return e.authID, true
+}
+
+// Set binds a session to an auth ID with the specified TTL. Overwriting an
+// existing binding always succeeds. When the store is at capacity, expired
+// bindings are reclaimed first; the write is silently dropped only if the store
+// is still full of live bindings afterwards.
+func (s *stickyStore) Set(sessionID, authID string, ttl time.Duration) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	now := time.Now()
+	if _, exists := s.entries[sessionID]; !exists && len(s.entries) >= s.maxEntries {
+		// Free slots held by expired bindings before refusing the write, so the
+		// cap recovers on its own instead of waiting for an external Cleanup call.
+		for k, e := range s.entries {
+			if now.After(e.expiresAt) {
+				delete(s.entries, k)
+			}
+		}
+		if len(s.entries) >= s.maxEntries {
+			return
+		}
+	}
+	s.entries[sessionID] = stickyEntry{authID: authID, expiresAt: now.Add(ttl)}
+}
+
+// Delete removes the binding for the given session ID.
+func (s *stickyStore) Delete(sessionID string) {
+	s.mu.Lock()
+	delete(s.entries, sessionID)
+	s.mu.Unlock()
+}
+
+// Cleanup removes all expired entries.
+func (s *stickyStore) Cleanup() { + now := time.Now() + s.mu.Lock() + for k, e := range s.entries { + if now.After(e.expiresAt) { + delete(s.entries, k) + } + } + s.mu.Unlock() +} + +// Len returns the number of entries (including possibly-expired ones). +func (s *stickyStore) Len() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.entries) +} diff --git a/sdk/cliproxy/auth/sticky_test.go b/sdk/cliproxy/auth/sticky_test.go new file mode 100644 index 0000000000..5b1d67512c --- /dev/null +++ b/sdk/cliproxy/auth/sticky_test.go @@ -0,0 +1,156 @@ +package auth + +import ( + "testing" + "time" +) + +func TestStickyStore_SetAndGet(t *testing.T) { + s := newStickyStore() + s.Set("sess-1", "auth-ai", time.Hour) + + got, ok := s.Get("sess-1") + if !ok || got != "auth-ai" { + t.Fatalf("expected auth-ai, got %q (ok=%v)", got, ok) + } +} + +func TestStickyStore_GetMiss(t *testing.T) { + s := newStickyStore() + _, ok := s.Get("nonexistent") + if ok { + t.Fatal("expected miss for nonexistent session") + } +} + +func TestStickyStore_GetExpired(t *testing.T) { + s := newStickyStore() + s.Set("sess-1", "auth-ai", time.Millisecond) + time.Sleep(2 * time.Millisecond) + + _, ok := s.Get("sess-1") + if ok { + t.Fatal("expected miss for expired entry") + } +} + +func TestStickyStore_Delete(t *testing.T) { + s := newStickyStore() + s.Set("sess-1", "auth-ai", time.Hour) + s.Delete("sess-1") + + _, ok := s.Get("sess-1") + if ok { + t.Fatal("expected miss after delete") + } +} + +func TestStickyStore_Overwrite(t *testing.T) { + s := newStickyStore() + s.Set("sess-1", "auth-ai", time.Hour) + s.Set("sess-1", "auth-cc", time.Hour) + + got, ok := s.Get("sess-1") + if !ok || got != "auth-cc" { + t.Fatalf("expected auth-cc after overwrite, got %q", got) + } +} + +func TestStickyStore_Cleanup(t *testing.T) { + s := newStickyStore() + s.Set("expired", "auth-ai", time.Millisecond) + s.Set("alive", "auth-cc", time.Hour) + time.Sleep(2 * time.Millisecond) + + s.Cleanup() + + if s.Len() != 1 { 
+ t.Fatalf("expected 1 entry after cleanup, got %d", s.Len()) + } + _, ok := s.Get("alive") + if !ok { + t.Fatal("alive entry should still exist") + } +} + +func TestStickyStore_Len(t *testing.T) { + s := newStickyStore() + if s.Len() != 0 { + t.Fatalf("expected 0, got %d", s.Len()) + } + s.Set("a", "x", time.Hour) + s.Set("b", "y", time.Hour) + if s.Len() != 2 { + t.Fatalf("expected 2, got %d", s.Len()) + } +} + +func TestStickyStore_MaxEntries(t *testing.T) { + s := newStickyStore() + s.maxEntries = 2 + + s.Set("a", "x", time.Hour) + s.Set("b", "y", time.Hour) + s.Set("c", "z", time.Hour) // should be silently dropped + + if s.Len() != 2 { + t.Fatalf("expected 2 (capped), got %d", s.Len()) + } + if _, ok := s.Get("c"); ok { + t.Fatal("entry 'c' should have been dropped due to capacity") + } + // overwriting existing entry should still work at capacity + s.Set("a", "updated", time.Hour) + got, ok := s.Get("a") + if !ok || got != "updated" { + t.Fatalf("expected 'updated' for overwrite at capacity, got %q (ok=%v)", got, ok) + } +} + +func TestStickyKey(t *testing.T) { + cases := []struct { + sessionID string + model string + want string + }{ + {"sess-1", "claude-3-opus", "sess-1|claude-3-opus"}, + {"sess-1", "claude-3-sonnet", "sess-1|claude-3-sonnet"}, + {"", "claude-3-opus", "|claude-3-opus"}, + {"sess-1", "", "sess-1|"}, + } + for _, tc := range cases { + got := stickyKey(tc.sessionID, tc.model) + if got != tc.want { + t.Errorf("stickyKey(%q, %q) = %q, want %q", tc.sessionID, tc.model, got, tc.want) + } + } +} + +func TestStickyStore_CompositeKey(t *testing.T) { + s := newStickyStore() + + // Same session ID, different models → independent bindings + k1 := stickyKey("sess-1", "claude-3-opus") + k2 := stickyKey("sess-1", "claude-3-sonnet") + + s.Set(k1, "auth-cc", time.Hour) + s.Set(k2, "auth-ai", time.Hour) + + got1, ok1 := s.Get(k1) + if !ok1 || got1 != "auth-cc" { + t.Fatalf("expected auth-cc for opus key, got %q (ok=%v)", got1, ok1) + } + got2, ok2 := 
s.Get(k2) + if !ok2 || got2 != "auth-ai" { + t.Fatalf("expected auth-ai for sonnet key, got %q (ok=%v)", got2, ok2) + } + + // Deleting one doesn't affect the other + s.Delete(k1) + if _, ok := s.Get(k1); ok { + t.Fatal("expected miss after deleting opus key") + } + if _, ok := s.Get(k2); !ok { + t.Fatal("sonnet key should still exist after deleting opus key") + } +} diff --git a/sdk/cliproxy/executor/types.go b/sdk/cliproxy/executor/types.go index 4ea8103947..88535c4d45 100644 --- a/sdk/cliproxy/executor/types.go +++ b/sdk/cliproxy/executor/types.go @@ -19,6 +19,8 @@ const ( SelectedAuthCallbackMetadataKey = "selected_auth_callback" // ExecutionSessionMetadataKey identifies a long-lived downstream execution session. ExecutionSessionMetadataKey = "execution_session_id" + // StickySessionMetadataKey carries the session ID for sticky auth routing. + StickySessionMetadataKey = "sticky_session_id" ) // Request encapsulates the translated payload that will be sent to a provider executor.