-
-
Notifications
You must be signed in to change notification settings - Fork 2.8k
feat: sticky session routing via X-Session-ID header #1998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -160,6 +160,9 @@ type Manager struct { | |
| // Optional HTTP RoundTripper provider injected by host. | ||
| rtProvider RoundTripperProvider | ||
|
|
||
| // sticky maintains session-to-auth bindings for sticky routing. | ||
| sticky *stickyStore | ||
|
|
||
| // Auto refresh state | ||
| refreshCancel context.CancelFunc | ||
| refreshSemaphore chan struct{} | ||
|
|
@@ -181,6 +184,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { | |
| auths: make(map[string]*Auth), | ||
| providerOffsets: make(map[string]int), | ||
| modelPoolOffsets: make(map[string]int), | ||
| sticky: newStickyStore(), | ||
| refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), | ||
| } | ||
| // atomic.Value requires non-nil initial value. | ||
|
|
@@ -483,7 +487,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC | |
| } | ||
| } | ||
|
|
||
| func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult { | ||
| func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk, stickySessionID string) *cliproxyexecutor.StreamResult { | ||
| out := make(chan cliproxyexecutor.StreamChunk) | ||
| go func() { | ||
| defer close(out) | ||
|
|
@@ -497,6 +501,12 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro | |
| rerr.HTTPStatus = se.StatusCode() | ||
| } | ||
| m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}) | ||
| // Sticky session: clear binding on mid-stream error. | ||
| if stickySessionID != "" { | ||
| if sc := statusCodeFromResult(rerr); sc == 429 || sc >= 500 { | ||
| m.sticky.Delete(stickySessionID) | ||
| } | ||
| } | ||
| } | ||
| if !forward { | ||
| return false | ||
|
|
@@ -532,7 +542,7 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro | |
| return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out} | ||
| } | ||
|
|
||
| func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) { | ||
| func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, stickySessionID string) (*cliproxyexecutor.StreamResult, error) { | ||
| if executor == nil { | ||
| return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} | ||
| } | ||
|
|
@@ -592,7 +602,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi | |
| errCh := make(chan cliproxyexecutor.StreamChunk, 1) | ||
| errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr} | ||
| close(errCh) | ||
| return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil | ||
| return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickySessionID), nil | ||
| } | ||
|
|
||
| if closed && len(buffered) == 0 { | ||
|
|
@@ -606,7 +616,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi | |
| errCh := make(chan cliproxyexecutor.StreamChunk, 1) | ||
| errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr} | ||
| close(errCh) | ||
| return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil | ||
| return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickySessionID), nil | ||
| } | ||
|
|
||
| remaining := streamResult.Chunks | ||
|
|
@@ -615,7 +625,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi | |
| close(closedCh) | ||
| remaining = closedCh | ||
| } | ||
| return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil | ||
| return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining, stickySessionID), nil | ||
| } | ||
| if lastErr == nil { | ||
| lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"} | ||
|
|
@@ -978,6 +988,18 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req | |
| } | ||
| routeModel := req.Model | ||
| opts = ensureRequestedModelMetadata(opts, routeModel) | ||
|
|
||
| // Sticky session: resolve session→auth binding before pick loop. | ||
| // An explicit pinned_auth_id takes precedence over sticky binding. | ||
| stickySessionID := stickySessionIDFromMetadata(opts.Metadata) | ||
| if stickySessionID != "" { | ||
| if _, alreadyPinned := opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned { | ||
| if boundAuth, found := m.sticky.Get(stickySessionID); found { | ||
| opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth | ||
| } | ||
| } | ||
| } | ||
|
||
|
|
||
| tried := make(map[string]struct{}) | ||
| var lastErr error | ||
| for { | ||
|
|
@@ -1025,12 +1047,24 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req | |
| result.RetryAfter = ra | ||
| } | ||
| m.MarkResult(execCtx, result) | ||
| // Sticky session: clear binding on rate-limit or server error so next | ||
| // request falls back to normal auth selection. | ||
| if stickySessionID != "" { | ||
| if sc := statusCodeFromResult(result.Error); sc == 429 || sc >= 500 { | ||
| m.sticky.Delete(stickySessionID) | ||
| delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey) | ||
| } | ||
| } | ||
| if isRequestInvalidError(errExec) { | ||
| return cliproxyexecutor.Response{}, errExec | ||
| } | ||
| authErr = errExec | ||
| continue | ||
| } | ||
| // Sticky session: bind session to the auth that succeeded. | ||
| if stickySessionID != "" { | ||
| m.sticky.Set(stickySessionID, auth.ID, stickySessionTTL) | ||
| } | ||
| m.MarkResult(execCtx, result) | ||
| return resp, nil | ||
| } | ||
|
|
@@ -1122,6 +1156,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string | |
| } | ||
| routeModel := req.Model | ||
| opts = ensureRequestedModelMetadata(opts, routeModel) | ||
|
|
||
| // Sticky session: resolve session→auth binding before pick loop. | ||
| // An explicit pinned_auth_id takes precedence over sticky binding. | ||
| stickySessionID := stickySessionIDFromMetadata(opts.Metadata) | ||
| if stickySessionID != "" { | ||
| if _, alreadyPinned := opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned { | ||
| if boundAuth, found := m.sticky.Get(stickySessionID); found { | ||
| opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth | ||
| } | ||
| } | ||
| } | ||
|
|
||
| tried := make(map[string]struct{}) | ||
| var lastErr error | ||
| for { | ||
|
|
@@ -1149,17 +1195,32 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string | |
| execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) | ||
| execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) | ||
| } | ||
| streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel) | ||
| streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, stickySessionID) | ||
| if errStream != nil { | ||
| if errCtx := execCtx.Err(); errCtx != nil { | ||
| return nil, errCtx | ||
| } | ||
| // Sticky session: clear binding on rate-limit or server error. | ||
| if stickySessionID != "" { | ||
| sc := 0 | ||
| if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { | ||
| sc = se.StatusCode() | ||
| } | ||
| if sc == 429 || sc >= 500 { | ||
|
||
| m.sticky.Delete(stickySessionID) | ||
| delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey) | ||
| } | ||
| } | ||
| if isRequestInvalidError(errStream) { | ||
| return nil, errStream | ||
| } | ||
| lastErr = errStream | ||
| continue | ||
| } | ||
| // Sticky session: bind session to the auth that started streaming. | ||
| if stickySessionID != "" { | ||
| m.sticky.Set(stickySessionID, auth.ID, stickySessionTTL) | ||
| } | ||
| return streamResult, nil | ||
| } | ||
| } | ||
|
|
@@ -1221,6 +1282,27 @@ func pinnedAuthIDFromMetadata(meta map[string]any) string { | |
| } | ||
| } | ||
|
|
||
| // stickySessionTTL is the duration a session-to-auth binding remains valid. | ||
| const stickySessionTTL = time.Hour | ||
|
|
||
| func stickySessionIDFromMetadata(meta map[string]any) string { | ||
| if len(meta) == 0 { | ||
| return "" | ||
| } | ||
| raw, ok := meta[cliproxyexecutor.StickySessionMetadataKey] | ||
| if !ok || raw == nil { | ||
| return "" | ||
| } | ||
| switch val := raw.(type) { | ||
| case string: | ||
| return strings.TrimSpace(val) | ||
| case []byte: | ||
| return strings.TrimSpace(string(val)) | ||
| default: | ||
| return "" | ||
| } | ||
| } | ||
|
|
||
| func publishSelectedAuthMetadata(meta map[string]any, authID string) { | ||
| if len(meta) == 0 { | ||
| return | ||
|
|
@@ -2331,13 +2413,19 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio | |
| go func() { | ||
| ticker := time.NewTicker(interval) | ||
| defer ticker.Stop() | ||
| stickyCleanupCounter := 0 | ||
| m.checkRefreshes(ctx) | ||
| for { | ||
| select { | ||
| case <-ctx.Done(): | ||
| return | ||
| case <-ticker.C: | ||
| m.checkRefreshes(ctx) | ||
| stickyCleanupCounter++ | ||
| if stickyCleanupCounter >= 60 { // ~every 5 min at default 5s interval | ||
| m.sticky.Cleanup() | ||
| stickyCleanupCounter = 0 | ||
| } | ||
|
||
| } | ||
| } | ||
| }() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| package auth | ||
|
|
||
| import ( | ||
| "sync" | ||
| "time" | ||
| ) | ||
|
|
||
type (
	// stickyStore keeps the session-to-auth bindings used for sticky routing,
	// so that requests carrying the same session ID are sent to the same
	// auth/account. Each binding has an expiry; expired bindings are ignored
	// by Get and removed by Cleanup.
	stickyStore struct {
		mu      sync.RWMutex           // guards entries
		entries map[string]stickyEntry // session ID -> binding
	}

	// stickyEntry is a single binding: the auth a session is tied to and the
	// instant after which that binding is considered expired.
	stickyEntry struct {
		authID    string
		expiresAt time.Time
	}
)

// newStickyStore returns an empty stickyStore whose map is already allocated.
func newStickyStore() *stickyStore {
	store := new(stickyStore)
	store.entries = map[string]stickyEntry{}
	return store
}
|
|
||
| // Get returns the bound auth ID for the given session, if it exists and has not | ||
| // expired. | ||
| func (s *stickyStore) Get(sessionID string) (string, bool) { | ||
| s.mu.RLock() | ||
| defer s.mu.RUnlock() | ||
| e, ok := s.entries[sessionID] | ||
| if !ok || time.Now().After(e.expiresAt) { | ||
| return "", false | ||
| } | ||
| return e.authID, true | ||
| } | ||
|
|
||
| // Set binds a session to an auth ID with the specified TTL. | ||
| func (s *stickyStore) Set(sessionID, authID string, ttl time.Duration) { | ||
| s.mu.Lock() | ||
| s.entries[sessionID] = stickyEntry{authID: authID, expiresAt: time.Now().Add(ttl)} | ||
| s.mu.Unlock() | ||
|
Comment on lines
+49
to
+56
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Furthermore, the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. Added |
||
| } | ||
|
|
||
| // Delete removes the binding for the given session ID. | ||
| func (s *stickyStore) Delete(sessionID string) { | ||
| s.mu.Lock() | ||
| delete(s.entries, sessionID) | ||
| s.mu.Unlock() | ||
| } | ||
|
|
||
| // Cleanup removes all expired entries. | ||
| func (s *stickyStore) Cleanup() { | ||
| now := time.Now() | ||
| s.mu.Lock() | ||
| for k, e := range s.entries { | ||
| if now.After(e.expiresAt) { | ||
| delete(s.entries, k) | ||
| } | ||
| } | ||
| s.mu.Unlock() | ||
| } | ||
|
|
||
| // Len returns the number of entries (including possibly-expired ones). | ||
| func (s *stickyStore) Len() int { | ||
| s.mu.RLock() | ||
| defer s.mu.RUnlock() | ||
| return len(s.entries) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `X-Session-ID` header is extracted and used directly as a key in the `stickyStore` without any length validation or sanitization. This untrusted input contributes to the potential memory exhaustion vulnerability in the `stickyStore`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Added 256-char length limit — headers exceeding it return empty (no binding created).