Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions sdk/api/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
}
// Sticky session: forward X-CLIProxyAPI-Session-ID header so the conductor can pin
// subsequent requests from the same session to the same auth account.
if sessionID := stickySessionIDFromHeader(ctx); sessionID != "" {
meta[coreexecutor.StickySessionMetadataKey] = sessionID
}
return meta
}

Expand Down Expand Up @@ -252,6 +257,21 @@ func executionSessionIDFromContext(ctx context.Context) string {
}
}

// stickySessionIDFromHeader returns the whitespace-trimmed value of the
// X-CLIProxyAPI-Session-ID request header, read from the *gin.Context stored
// in ctx under the "gin" key.
//
// It returns "" when ctx is nil, when no usable gin context or request is
// attached, when the header is absent or blank, or when the trimmed value is
// longer than coreauth.StickyMaxSessionIDLen bytes.
func stickySessionIDFromHeader(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	// A failed assertion leaves c nil, so one nil check covers both a missing
	// value and a value of the wrong type.
	c, _ := ctx.Value("gin").(*gin.Context)
	if c == nil || c.Request == nil {
		return ""
	}
	sessionID := strings.TrimSpace(c.GetHeader("X-CLIProxyAPI-Session-ID"))
	if len(sessionID) <= coreauth.StickyMaxSessionIDLen {
		return sessionID
	}
	// Oversized IDs are rejected outright rather than truncated.
	return ""
}
Comment on lines +260 to +273
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The X-Session-ID header is extracted and used directly as a key in the stickyStore without any length validation or sanitization. This untrusted input contributes to the potential memory exhaustion vulnerability in the stickyStore.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Added 256-char length limit — headers exceeding it return empty (no binding created).


// BaseAPIHandler contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service and manages
// load balancing, client selection, and configuration.
Expand Down
115 changes: 103 additions & 12 deletions sdk/cliproxy/auth/conductor.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ type Manager struct {
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider

// sticky maintains session-to-auth bindings for sticky routing.
sticky *stickyStore

// Auto refresh state
refreshCancel context.CancelFunc
refreshSemaphore chan struct{}
Expand All @@ -181,6 +184,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
auths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
sticky: newStickyStore(),
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
}
// atomic.Value requires non-nil initial value.
Expand Down Expand Up @@ -483,7 +487,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC
}
}

func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk, stickyKey string) *cliproxyexecutor.StreamResult {
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
Expand All @@ -497,6 +501,12 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
// Sticky session: clear binding on mid-stream error.
if stickyKey != "" {
if sc := statusCodeFromResult(rerr); sc == 429 || sc >= 500 {
m.sticky.Delete(stickyKey)
}
}
}
if !forward {
return false
Expand Down Expand Up @@ -532,9 +542,9 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
}

func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, stickyKey string) (*cliproxyexecutor.StreamResult, bool, error) {
if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
return nil, false, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
execModels := m.prepareExecutionModels(auth, routeModel)
var lastErr error
Expand All @@ -544,7 +554,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
if errStream != nil {
if errCtx := ctx.Err(); errCtx != nil {
return nil, errCtx
return nil, false, errCtx
}
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
Expand All @@ -554,7 +564,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(ctx, result)
if isRequestInvalidError(errStream) {
return nil, errStream
return nil, false, errStream
}
lastErr = errStream
continue
Expand All @@ -564,7 +574,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
if bootstrapErr != nil {
if errCtx := ctx.Err(); errCtx != nil {
discardStreamChunks(streamResult.Chunks)
return nil, errCtx
return nil, false, errCtx
}
if isRequestInvalidError(bootstrapErr) {
rerr := &Error{Message: bootstrapErr.Error()}
Expand All @@ -575,7 +585,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
return nil, bootstrapErr
return nil, false, bootstrapErr
}
if idx < len(execModels)-1 {
rerr := &Error{Message: bootstrapErr.Error()}
Expand All @@ -592,7 +602,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickyKey), false, nil
}

if closed && len(buffered) == 0 {
Expand All @@ -606,7 +616,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickyKey), false, nil
}

remaining := streamResult.Chunks
Expand All @@ -615,12 +625,12 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
close(closedCh)
remaining = closedCh
}
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining, stickyKey), true, nil
}
if lastErr == nil {
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
}
return nil, lastErr
return nil, false, lastErr
}

func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
Expand Down Expand Up @@ -978,6 +988,10 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)

stickySessionID := m.resolveStickySession(opts.Metadata, routeModel)
sk := stickyKey(stickySessionID, routeModel)

tried := make(map[string]struct{})
var lastErr error
for {
Expand Down Expand Up @@ -1025,12 +1039,24 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
// Sticky session: clear binding on rate-limit or server error so next
// request falls back to normal auth selection.
if sk != "" {
if sc := statusCodeFromResult(result.Error); sc == 429 || sc >= 500 {
m.sticky.Delete(sk)
delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey)
}
}
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
// Sticky session: bind session to the auth that succeeded.
if sk != "" {
m.sticky.Set(sk, auth.ID, stickySessionTTL)
}
m.MarkResult(execCtx, result)
return resp, nil
}
Expand Down Expand Up @@ -1122,6 +1148,10 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)

stickySessionID := m.resolveStickySession(opts.Metadata, routeModel)
sk := stickyKey(stickySessionID, routeModel)

tried := make(map[string]struct{})
var lastErr error
for {
Expand Down Expand Up @@ -1149,17 +1179,28 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
streamResult, streamOK, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, sk)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
// Sticky session: clear binding on rate-limit or server error.
if sk != "" {
if sc := statusCodeFromError(errStream); sc == 429 || sc >= 500 {
m.sticky.Delete(sk)
delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey)
}
}
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
// Sticky session: only bind on genuine stream success.
if sk != "" && streamOK {
m.sticky.Set(sk, auth.ID, stickySessionTTL)
}
return streamResult, nil
}
}
Expand Down Expand Up @@ -1221,6 +1262,50 @@ func pinnedAuthIDFromMetadata(meta map[string]any) string {
}
}

// stickySessionTTL is the duration a session-to-auth binding remains valid.
// The execute paths pass it to the sticky store each time a request carrying a
// sticky key succeeds, so an active session keeps extending its binding.
const stickySessionTTL = time.Hour

// stickyKey builds the composite store key for a session/model pair so that
// the same session ID used with different models gets independent bindings.
//
// It returns "" when sessionID is empty. Callers treat a non-empty key as the
// signal that sticky routing is active for the request; without this guard a
// request that carries no session ID would still produce the key "|"+model
// and be written to (and deleted from) the sticky store.
func stickyKey(sessionID, model string) string {
	if sessionID == "" {
		return ""
	}
	return sessionID + "|" + model
}

// stickySessionIDFromMetadata reads the sticky session ID stored in meta under
// cliproxyexecutor.StickySessionMetadataKey and returns it with surrounding
// whitespace removed. Both string and []byte values are accepted; a missing
// key, a nil value, or a value of any other type yields "".
func stickySessionIDFromMetadata(meta map[string]any) string {
	// Indexing a nil or empty map yields a nil interface, which fails both
	// assertions below and falls through to the empty result.
	value := meta[cliproxyexecutor.StickySessionMetadataKey]
	if text, isString := value.(string); isString {
		return strings.TrimSpace(text)
	}
	if raw, isBytes := value.([]byte); isBytes {
		return strings.TrimSpace(string(raw))
	}
	return ""
}

// resolveStickySession returns the sticky session ID found in meta, or "" when
// the request carries none.
//
// As a side effect, when a session ID is present and the sticky store holds a
// live binding for that session and routeModel, the bound auth ID is written
// into meta under cliproxyexecutor.PinnedAuthMetadataKey. If that key is
// already present in meta, whatever its value, the existing pin is left
// untouched and the store is not consulted.
func (m *Manager) resolveStickySession(meta map[string]any, routeModel string) string {
	sessionID := stickySessionIDFromMetadata(meta)
	if sessionID == "" {
		return ""
	}
	if _, pinned := meta[cliproxyexecutor.PinnedAuthMetadataKey]; pinned {
		// An explicit pin takes precedence over any sticky binding.
		return sessionID
	}
	// Bindings are scoped per model, so the lookup uses the composite key.
	authID, bound := m.sticky.Get(stickyKey(sessionID, routeModel))
	if bound {
		meta[cliproxyexecutor.PinnedAuthMetadataKey] = authID
	}
	return sessionID
}

func publishSelectedAuthMetadata(meta map[string]any, authID string) {
if len(meta) == 0 {
return
Expand Down Expand Up @@ -2331,13 +2416,19 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
lastStickyCleanup := time.Now()
const stickyCleanupInterval = 5 * time.Minute
m.checkRefreshes(ctx)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.checkRefreshes(ctx)
if time.Since(lastStickyCleanup) >= stickyCleanupInterval {
m.sticky.Cleanup()
lastStickyCleanup = time.Now()
}
}
}
}()
Expand Down
83 changes: 83 additions & 0 deletions sdk/cliproxy/auth/sticky.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package auth

import (
"sync"
"time"
)

// stickyStore maintains session-to-auth bindings so that requests carrying the
// same session ID are routed to the same auth/account. Each entry carries its
// own expiry time, set from the TTL passed to Set; expired entries are ignored
// by Get and removed from the map by Cleanup.
//
// maxEntries caps the number of stored bindings to prevent memory exhaustion
// from untrusted X-CLIProxyAPI-Session-ID headers.
type stickyStore struct {
// mu guards entries; every method takes it before touching the map.
mu sync.RWMutex
// entries maps a store key to its binding.
entries map[string]stickyEntry
// maxEntries is the capacity above which Set refuses to add new keys.
maxEntries int
}

// stickyEntry is a single binding held by stickyStore.
type stickyEntry struct {
// authID is the auth the key is bound to.
authID string
// expiresAt is the instant after which the binding is treated as expired.
expiresAt time.Time
}

// stickyMaxEntries is the upper bound on stored session bindings; it is the
// capacity newStickyStore assigns to every store it creates.
const stickyMaxEntries = 10_000

// StickyMaxSessionIDLen limits the accepted X-CLIProxyAPI-Session-ID length.
// It is exported so the code that reads the header can enforce the limit
// before the value is used as a store key.
const StickyMaxSessionIDLen = 256

// newStickyStore returns an empty, ready-to-use store whose capacity is
// stickyMaxEntries.
func newStickyStore() *stickyStore {
	store := new(stickyStore)
	store.entries = make(map[string]stickyEntry)
	store.maxEntries = stickyMaxEntries
	return store
}

// Get looks up the binding stored under sessionID and returns its auth ID.
// The second result is false when no binding exists or when the binding's
// expiry time has already passed; expired entries are not removed here.
func (s *stickyStore) Get(sessionID string) (string, bool) {
	s.mu.RLock()
	entry, found := s.entries[sessionID]
	s.mu.RUnlock()

	if !found {
		return "", false
	}
	if entry.expiresAt.Before(time.Now()) {
		return "", false
	}
	return entry.authID, true
}

// Set binds a session to an auth ID with the specified TTL.
// If the store is at capacity, a write for a new session ID is silently
// dropped; a write for an ID that is already stored still overwrites it.
func (s *stickyStore) Set(sessionID, authID string, ttl time.Duration) {
s.mu.Lock()
if _, exists := s.entries[sessionID]; !exists && len(s.entries) >= s.maxEntries {
s.mu.Unlock()
return
}
s.entries[sessionID] = stickyEntry{authID: authID, expiresAt: time.Now().Add(ttl)}
s.mu.Unlock()
Comment on lines +49 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The stickyStore uses an unbounded in-memory map to store session-to-auth bindings. Since the sessionID is derived from the user-controlled X-Session-ID header, an attacker can send a large number of requests with unique session IDs to exhaust the server's memory.

Furthermore, the Cleanup function (lines 52-61) iterates over the entire map while holding a write lock. If the map grows very large due to an attack, this cleanup process will block all other operations on the stickyStore for an extended period, leading to a denial of service.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Added maxEntries cap (10k) — new entries are silently dropped at capacity. Overwriting existing keys still works. Added TestStickyStore_MaxEntries covering both cases.

}

// Delete drops the binding stored under sessionID. Deleting a key that is not
// present is a no-op.
func (s *stickyStore) Delete(sessionID string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.entries, sessionID)
}

// Cleanup sweeps the store and deletes every entry whose expiry time lies
// before the moment Cleanup was called. The write lock is held for the whole
// sweep.
func (s *stickyStore) Cleanup() {
	cutoff := time.Now()
	s.mu.Lock()
	defer s.mu.Unlock()
	for key, entry := range s.entries {
		if entry.expiresAt.Before(cutoff) {
			delete(s.entries, key)
		}
	}
}

// Len reports how many entries the store currently holds. Entries that have
// expired but have not yet been removed by Cleanup are included in the count.
func (s *stickyStore) Len() int {
	s.mu.RLock()
	count := len(s.entries)
	s.mu.RUnlock()
	return count
}
Loading
Loading