Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sdk/api/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
}
// Sticky session: forward X-Session-ID header so the conductor can pin
// subsequent requests from the same session to the same auth account.
if sessionID := stickySessionIDFromHeader(ctx); sessionID != "" {
meta[coreexecutor.StickySessionMetadataKey] = sessionID
}
return meta
}

Expand Down Expand Up @@ -252,6 +257,17 @@ func executionSessionIDFromContext(ctx context.Context) string {
}
}

// maxStickySessionIDLen bounds the untrusted X-Session-ID header. Oversized
// values are rejected outright (no binding created) rather than truncated, so
// distinct long IDs cannot silently collide, and attacker-controlled headers
// cannot feed arbitrarily large keys into the sticky store.
const maxStickySessionIDLen = 256

// stickySessionIDFromHeader extracts the trimmed X-Session-ID header from the
// gin request carried in ctx. It returns "" when ctx is nil, when no gin
// context/request is attached, or when the header exceeds
// maxStickySessionIDLen.
func stickySessionIDFromHeader(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	// NOTE(review): the gin context is stashed under the string key "gin";
	// presumably set by upstream middleware — confirm against the router setup.
	ginCtx, ok := ctx.Value("gin").(*gin.Context)
	if !ok || ginCtx == nil || ginCtx.Request == nil {
		return ""
	}
	sessionID := strings.TrimSpace(ginCtx.GetHeader("X-Session-ID"))
	if len(sessionID) > maxStickySessionIDLen {
		return ""
	}
	return sessionID
}
Comment on lines +260 to +273
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The X-Session-ID header is extracted and used directly as a key in the stickyStore without any length validation or sanitization. This untrusted input contributes to the potential memory exhaustion vulnerability in the stickyStore.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Added 256-char length limit — headers exceeding it return empty (no binding created).


// BaseAPIHandler contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service and manages
// load balancing, client selection, and configuration.
Expand Down
100 changes: 94 additions & 6 deletions sdk/cliproxy/auth/conductor.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ type Manager struct {
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider

// sticky maintains session-to-auth bindings for sticky routing.
sticky *stickyStore

// Auto refresh state
refreshCancel context.CancelFunc
refreshSemaphore chan struct{}
Expand All @@ -181,6 +184,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
auths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
sticky: newStickyStore(),
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
}
// atomic.Value requires non-nil initial value.
Expand Down Expand Up @@ -483,7 +487,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC
}
}

func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk, stickySessionID string) *cliproxyexecutor.StreamResult {
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
Expand All @@ -497,6 +501,12 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
// Sticky session: clear binding on mid-stream error.
if stickySessionID != "" {
if sc := statusCodeFromResult(rerr); sc == 429 || sc >= 500 {
m.sticky.Delete(stickySessionID)
}
}
}
if !forward {
return false
Expand Down Expand Up @@ -532,7 +542,7 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
}

func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, stickySessionID string) (*cliproxyexecutor.StreamResult, error) {
if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
Expand Down Expand Up @@ -592,7 +602,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickySessionID), nil
}

if closed && len(buffered) == 0 {
Expand All @@ -606,7 +616,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickySessionID), nil
}

remaining := streamResult.Chunks
Expand All @@ -615,7 +625,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
close(closedCh)
remaining = closedCh
}
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining, stickySessionID), nil
}
if lastErr == nil {
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
Expand Down Expand Up @@ -978,6 +988,18 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)

// Sticky session: resolve session→auth binding before pick loop.
// An explicit pinned_auth_id takes precedence over sticky binding.
stickySessionID := stickySessionIDFromMetadata(opts.Metadata)
if stickySessionID != "" {
if _, alreadyPinned := opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned {
if boundAuth, found := m.sticky.Get(stickySessionID); found {
opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth
}
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of logic for resolving sticky sessions is duplicated in executeStreamMixedOnce (lines 1160-1169). To improve maintainability and reduce code duplication, consider extracting this logic into a private helper method on the Manager.

For example:

func (m *Manager) resolveStickySession(meta map[string]any) string {
    stickySessionID := stickySessionIDFromMetadata(meta)
    if stickySessionID != "" {
        if _, alreadyPinned := meta[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned {
            if boundAuth, found := m.sticky.Get(stickySessionID); found {
                meta[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth
            }
        }
    }
    return stickySessionID
}

You could then call this helper in both executeMixedOnce and executeStreamMixedOnce like so: stickySessionID := m.resolveStickySession(opts.Metadata).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracted resolveStickySession helper — both executeMixedOnce and executeStreamMixedOnce now call it.


tried := make(map[string]struct{})
var lastErr error
for {
Expand Down Expand Up @@ -1025,12 +1047,24 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
// Sticky session: clear binding on rate-limit or server error so next
// request falls back to normal auth selection.
if stickySessionID != "" {
if sc := statusCodeFromResult(result.Error); sc == 429 || sc >= 500 {
m.sticky.Delete(stickySessionID)
delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey)
}
}
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
// Sticky session: bind session to the auth that succeeded.
if stickySessionID != "" {
m.sticky.Set(stickySessionID, auth.ID, stickySessionTTL)
}
m.MarkResult(execCtx, result)
return resp, nil
}
Expand Down Expand Up @@ -1122,6 +1156,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)

// Sticky session: resolve session→auth binding before pick loop.
// An explicit pinned_auth_id takes precedence over sticky binding.
stickySessionID := stickySessionIDFromMetadata(opts.Metadata)
if stickySessionID != "" {
if _, alreadyPinned := opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned {
if boundAuth, found := m.sticky.Get(stickySessionID); found {
opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth
}
}
}

tried := make(map[string]struct{})
var lastErr error
for {
Expand Down Expand Up @@ -1149,17 +1195,32 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, stickySessionID)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
// Sticky session: clear binding on rate-limit or server error.
if stickySessionID != "" {
sc := 0
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
sc = se.StatusCode()
}
if sc == 429 || sc >= 500 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic to extract the status code from an error appears to be a re-implementation of the existing statusCodeFromError helper function. Using the helper function directly would improve consistency and readability.

if sc := statusCodeFromError(errStream); sc == 429 || sc >= 500 {

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — switched to statusCodeFromError(errStream).

m.sticky.Delete(stickySessionID)
delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey)
}
}
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
// Sticky session: bind session to the auth that started streaming.
if stickySessionID != "" {
m.sticky.Set(stickySessionID, auth.ID, stickySessionTTL)
}
return streamResult, nil
}
}
Expand Down Expand Up @@ -1221,6 +1282,27 @@ func pinnedAuthIDFromMetadata(meta map[string]any) string {
}
}

// stickySessionTTL is the duration a session-to-auth binding remains valid.
const stickySessionTTL = time.Hour

// stickySessionIDFromMetadata returns the trimmed sticky-session ID carried in
// the request metadata, or "" when the key is absent, nil, or holds a type
// other than string/[]byte.
func stickySessionIDFromMetadata(meta map[string]any) string {
	if len(meta) == 0 {
		return ""
	}
	raw, ok := meta[cliproxyexecutor.StickySessionMetadataKey]
	if !ok || raw == nil {
		return ""
	}
	if s, isString := raw.(string); isString {
		return strings.TrimSpace(s)
	}
	if b, isBytes := raw.([]byte); isBytes {
		return strings.TrimSpace(string(b))
	}
	return ""
}

func publishSelectedAuthMetadata(meta map[string]any, authID string) {
if len(meta) == 0 {
return
Expand Down Expand Up @@ -2331,13 +2413,19 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
stickyCleanupCounter := 0
m.checkRefreshes(ctx)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.checkRefreshes(ctx)
stickyCleanupCounter++
if stickyCleanupCounter >= 60 { // ~every 5 min at default 5s interval
m.sticky.Cleanup()
stickyCleanupCounter = 0
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation for triggering m.sticky.Cleanup() relies on a counter and a magic number 60. As the comment indicates, this is tied to the default 5s interval and is brittle if the interval parameter changes. A more robust approach would be to base the cleanup on elapsed time, making it independent of the ticker's interval.

For example:

// Before the loop
lastStickyCleanup := time.Now()
const stickyCleanupInterval = 5 * time.Minute

// Inside the loop
if time.Since(lastStickyCleanup) >= stickyCleanupInterval {
    m.sticky.Cleanup()
    lastStickyCleanup = time.Now()
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced counter with time.Since(lastStickyCleanup) — now interval-independent.

}
}
}()
Expand Down
68 changes: 68 additions & 0 deletions sdk/cliproxy/auth/sticky.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package auth

import (
"sync"
"time"
)

// stickyStore maintains session-to-auth bindings so that requests carrying the
// same session ID are routed to the same auth/account. Entries expire after a
// configurable TTL and are garbage-collected by Cleanup.
//
// The store is bounded: once maxStickyEntries bindings exist, Set silently
// drops inserts of NEW session IDs (overwrites of existing IDs still succeed).
// Session IDs originate from an attacker-controllable header, so an unbounded
// map would be a memory-exhaustion vector.
type stickyStore struct {
	mu      sync.RWMutex
	entries map[string]stickyEntry
}

// maxStickyEntries caps the number of concurrent session bindings held in
// memory.
const maxStickyEntries = 10000

type stickyEntry struct {
	authID    string    // bound auth/account ID
	expiresAt time.Time // binding is invalid after this instant
}

// newStickyStore returns an empty, ready-to-use store.
func newStickyStore() *stickyStore {
	return &stickyStore{entries: make(map[string]stickyEntry)}
}

// Get returns the bound auth ID for the given session, if it exists and has not
// expired. Expired entries are reported as absent but are only removed by
// Cleanup (Get holds just a read lock).
func (s *stickyStore) Get(sessionID string) (string, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	e, ok := s.entries[sessionID]
	if !ok || time.Now().After(e.expiresAt) {
		return "", false
	}
	return e.authID, true
}

// Set binds a session to an auth ID with the specified TTL. When the store is
// at capacity, inserts of new session IDs are dropped; updating an existing
// binding always succeeds.
func (s *stickyStore) Set(sessionID, authID string, ttl time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if _, exists := s.entries[sessionID]; !exists && len(s.entries) >= maxStickyEntries {
		// At capacity: refuse to grow on attacker-controllable keys.
		return
	}
	s.entries[sessionID] = stickyEntry{authID: authID, expiresAt: time.Now().Add(ttl)}
}

// Delete removes the binding for the given session ID.
func (s *stickyStore) Delete(sessionID string) {
	s.mu.Lock()
	delete(s.entries, sessionID)
	s.mu.Unlock()
}

// Cleanup removes all expired entries. Intended to be called periodically
// (e.g. from the auto-refresh loop) to reclaim memory.
func (s *stickyStore) Cleanup() {
	now := time.Now()
	s.mu.Lock()
	for k, e := range s.entries {
		if now.After(e.expiresAt) {
			delete(s.entries, k)
		}
	}
	s.mu.Unlock()
}

// Len returns the number of entries (including possibly-expired ones).
func (s *stickyStore) Len() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return len(s.entries)
}
Loading