-
-
Notifications
You must be signed in to change notification settings - Fork 3.1k
fix: 修复 Codex WebSocket 会话在额度耗尽后不切换账号 #2256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
cc5cb31
32e8d1d
b9b9f77
ae8557e
42acfa7
e9a98a3
51ea8f1
7e7cc1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -117,6 +117,59 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { | |
| s.activeMu.Unlock() | ||
| } | ||
|
|
||
| func (s *codexWebsocketSession) isCurrentConn(conn *websocket.Conn) bool { | ||
| if s == nil || conn == nil { | ||
| return false | ||
| } | ||
| s.connMu.Lock() | ||
| current := s.conn | ||
| s.connMu.Unlock() | ||
| return current == conn | ||
| } | ||
|
|
||
| func (s *codexWebsocketSession) activeSnapshotForCurrentConn(conn *websocket.Conn) (chan codexWebsocketRead, <-chan struct{}, bool) { | ||
| if s == nil || conn == nil { | ||
| return nil, nil, false | ||
| } | ||
| s.connMu.Lock() | ||
| if s.conn != conn { | ||
| s.connMu.Unlock() | ||
| return nil, nil, false | ||
| } | ||
| s.activeMu.Lock() | ||
| ch := s.activeCh | ||
| done := s.activeDone | ||
| s.activeMu.Unlock() | ||
| s.connMu.Unlock() | ||
| return ch, done, true | ||
| } | ||
|
|
||
| func (s *codexWebsocketSession) clearActiveForCurrentConn(conn *websocket.Conn, ch chan codexWebsocketRead) bool { | ||
| if s == nil || conn == nil || ch == nil { | ||
| return false | ||
| } | ||
| s.connMu.Lock() | ||
| if s.conn != conn { | ||
| s.connMu.Unlock() | ||
| return false | ||
| } | ||
| s.activeMu.Lock() | ||
| if s.activeCh != ch { | ||
| s.activeMu.Unlock() | ||
| s.connMu.Unlock() | ||
| return false | ||
| } | ||
| s.activeCh = nil | ||
| if s.activeCancel != nil { | ||
| s.activeCancel() | ||
| } | ||
| s.activeCancel = nil | ||
| s.activeDone = nil | ||
| s.activeMu.Unlock() | ||
| s.connMu.Unlock() | ||
| return true | ||
| } | ||
|
|
||
| func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { | ||
| if s == nil { | ||
| return fmt.Errorf("codex websockets executor: session is nil") | ||
|
|
@@ -1067,12 +1120,21 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth * | |
| if sess == nil { | ||
| return e.dialCodexWebsocket(ctx, auth, wsURL, headers) | ||
| } | ||
| authID = strings.TrimSpace(authID) | ||
|
|
||
| sess.connMu.Lock() | ||
| conn := sess.conn | ||
| readerConn := sess.readerConn | ||
| currentAuthID := strings.TrimSpace(sess.authID) | ||
| sess.connMu.Unlock() | ||
| if conn != nil && currentAuthID != authID { | ||
| // 账号切换时先断开旧连接避免继续复用旧账号 | ||
| e.invalidateUpstreamConn(sess, conn, "auth_switched", nil) | ||
| conn = nil | ||
| readerConn = nil | ||
| } | ||
| if conn != nil { | ||
| // 账号未变化时复用连接减少不必要重连 | ||
| if readerConn != conn { | ||
| sess.connMu.Lock() | ||
| sess.readerConn = conn | ||
|
|
@@ -1114,21 +1176,28 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, | |
| return | ||
| } | ||
| for { | ||
| if !sess.isCurrentConn(conn) { | ||
| // 旧连接读循环直接退出避免误伤新请求通道 | ||
| return | ||
| } | ||
| _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) | ||
| msgType, payload, errRead := conn.ReadMessage() | ||
| if errRead != nil { | ||
| sess.activeMu.Lock() | ||
| ch := sess.activeCh | ||
| done := sess.activeDone | ||
| sess.activeMu.Unlock() | ||
| // 在同一临界区做归属校验和通道快照避免检查后竞态 | ||
| ch, done, current := sess.activeSnapshotForCurrentConn(conn) | ||
| if !current { | ||
| // 旧连接读错时不触碰当前活跃通道 | ||
| return | ||
| } | ||
|
Comment on lines
+1222
to
+1225
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Do not return immediately on Useful? React with 👍 / 👎. |
||
| if ch != nil { | ||
| select { | ||
| case ch <- codexWebsocketRead{conn: conn, err: errRead}: | ||
| case <-done: | ||
| default: | ||
| } | ||
| sess.clearActive(ch) | ||
| close(ch) | ||
| if sess.clearActiveForCurrentConn(conn, ch) { | ||
| close(ch) | ||
| } | ||
| } | ||
| e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead) | ||
| return | ||
|
|
@@ -1137,29 +1206,33 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, | |
| if msgType != websocket.TextMessage { | ||
| if msgType == websocket.BinaryMessage { | ||
| errBinary := fmt.Errorf("codex websockets executor: unexpected binary message") | ||
| sess.activeMu.Lock() | ||
| ch := sess.activeCh | ||
| done := sess.activeDone | ||
| sess.activeMu.Unlock() | ||
| // 在同一临界区做归属校验和通道快照避免检查后竞态 | ||
| ch, done, current := sess.activeSnapshotForCurrentConn(conn) | ||
| if !current { | ||
| // 旧连接二进制异常时不触碰当前活跃通道 | ||
| return | ||
| } | ||
| if ch != nil { | ||
| select { | ||
| case ch <- codexWebsocketRead{conn: conn, err: errBinary}: | ||
| case <-done: | ||
| default: | ||
| } | ||
| sess.clearActive(ch) | ||
| close(ch) | ||
| if sess.clearActiveForCurrentConn(conn, ch) { | ||
| close(ch) | ||
| } | ||
| } | ||
| e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) | ||
| return | ||
| } | ||
| continue | ||
| } | ||
|
|
||
| sess.activeMu.Lock() | ||
| ch := sess.activeCh | ||
| done := sess.activeDone | ||
| sess.activeMu.Unlock() | ||
| // 在同一临界区做归属校验和通道快照避免检查后竞态 | ||
| ch, done, current := sess.activeSnapshotForCurrentConn(conn) | ||
| if !current { | ||
| // 旧连接消息不再分发给新连接请求 | ||
| return | ||
| } | ||
| if ch == nil { | ||
| continue | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,6 +81,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { | |
| var lastRequest []byte | ||
| lastResponseOutput := []byte("[]") | ||
| pinnedAuthID := "" | ||
| forceDisableIncrementalAfterAuthReset := false | ||
|
|
||
| for { | ||
| msgType, payload, errReadMessage := conn.ReadMessage() | ||
|
|
@@ -107,16 +108,18 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { | |
| appendWebsocketEvent(&wsBodyLog, "request", payload) | ||
|
|
||
| allowIncrementalInputWithPreviousResponseID := false | ||
| if pinnedAuthID != "" && h != nil && h.AuthManager != nil { | ||
| if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { | ||
| allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) | ||
| } | ||
| } else { | ||
| requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) | ||
| if requestModelName == "" { | ||
| requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) | ||
| if !forceDisableIncrementalAfterAuthReset { | ||
| if pinnedAuthID != "" && h != nil && h.AuthManager != nil { | ||
| if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { | ||
| allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) | ||
| } | ||
| } else { | ||
| requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) | ||
| if requestModelName == "" { | ||
| requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) | ||
| } | ||
| allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) | ||
| } | ||
| allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) | ||
| } | ||
|
|
||
| var requestJSON []byte | ||
|
|
@@ -151,6 +154,28 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { | |
| } | ||
| continue | ||
| } | ||
| nextSessionRequestSnapshot := updatedLastRequest | ||
| if shouldBuildResponsesWebsocketFullSnapshot(requestJSON, allowIncrementalInputWithPreviousResponseID) { | ||
| _, shadowLastRequest, shadowErr := normalizeResponsesWebsocketRequestWithMode( | ||
| payload, | ||
| lastRequest, | ||
| lastResponseOutput, | ||
| false, | ||
| ) | ||
| if shadowErr != nil { | ||
| // 影子快照失败时保留旧快照避免污染会话状态 | ||
| nextSessionRequestSnapshot = lastRequest | ||
| log.Warnf( | ||
| "responses websocket: keep previous snapshot id=%s status=%d error=%v", | ||
| passthroughSessionID, | ||
| shadowErr.StatusCode, | ||
| shadowErr.Error, | ||
| ) | ||
| } else { | ||
| // 增量模式只发 delta 同时维护完整快照供切号恢复 | ||
| nextSessionRequestSnapshot = shadowLastRequest | ||
| } | ||
| } | ||
| if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { | ||
| if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil { | ||
| requestJSON = updated | ||
|
|
@@ -167,7 +192,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { | |
| } | ||
| continue | ||
| } | ||
| lastRequest = updatedLastRequest | ||
| // lastRequest 始终保存完整 transcript 快照 | ||
| lastRequest = nextSessionRequestSnapshot | ||
|
||
|
|
||
| modelName := gjson.GetBytes(requestJSON, "model").String() | ||
| cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) | ||
|
|
@@ -192,13 +218,26 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { | |
| } | ||
| dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") | ||
|
|
||
| completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) | ||
| completedOutput, terminalStatus, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) | ||
| if errForward != nil { | ||
| wsTerminateErr = errForward | ||
| appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) | ||
| log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) | ||
| return | ||
| } | ||
| if shouldResetResponsesWebsocketAuthPin(terminalStatus) { | ||
| // 限额错误后解除 pin 让后续请求重新选可用账号 | ||
| pinnedAuthID = "" | ||
| // 切号恢复阶段先禁用增量模式避免沿用旧账号 response id | ||
| forceDisableIncrementalAfterAuthReset = true | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Setting Useful? React with 👍 / 👎. |
||
| if h != nil && h.AuthManager != nil { | ||
|
Comment on lines
+225
to
+231
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
After a quota-class terminal error this block clears Useful? React with 👍 / 👎. |
||
| // 主动关闭旧上游会话避免继续复用旧账号连接 | ||
| h.AuthManager.CloseExecutionSession(passthroughSessionID) | ||
| } | ||
| } else if forceDisableIncrementalAfterAuthReset && terminalStatus == 0 { | ||
| // 新账号完成一轮后恢复增量模式 | ||
| forceDisableIncrementalAfterAuthReset = false | ||
| } | ||
| lastResponseOutput = completedOutput | ||
| } | ||
| } | ||
|
|
@@ -488,6 +527,14 @@ func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest [] | |
| return generateResult.Exists() && !generateResult.Bool() | ||
| } | ||
|
|
||
| func shouldBuildResponsesWebsocketFullSnapshot(normalizedRequestJSON []byte, allowIncrementalInputWithPreviousResponseID bool) bool { | ||
| if !allowIncrementalInputWithPreviousResponseID { | ||
| return false | ||
| } | ||
| prev := strings.TrimSpace(gjson.GetBytes(normalizedRequestJSON, "previous_response_id").String()) | ||
| return prev != "" | ||
| } | ||
|
|
||
| func writeResponsesWebsocketSyntheticPrewarm( | ||
| c *gin.Context, | ||
| conn *websocket.Conn, | ||
|
|
@@ -610,21 +657,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( | |
| errs <-chan *interfaces.ErrorMessage, | ||
| wsBodyLog *strings.Builder, | ||
| sessionID string, | ||
| ) ([]byte, error) { | ||
| ) ([]byte, int, error) { | ||
| completed := false | ||
| completedOutput := []byte("[]") | ||
| terminalStatusCode := 0 | ||
|
|
||
| for { | ||
| select { | ||
| case <-c.Request.Context().Done(): | ||
| cancel(c.Request.Context().Err()) | ||
| return completedOutput, c.Request.Context().Err() | ||
| return completedOutput, terminalStatusCode, c.Request.Context().Err() | ||
| case errMsg, ok := <-errs: | ||
| if !ok { | ||
| errs = nil | ||
| continue | ||
| } | ||
| if errMsg != nil { | ||
| terminalStatusCode = errMsg.StatusCode | ||
| h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) | ||
| markAPIResponseTimestamp(c) | ||
| errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) | ||
|
|
@@ -644,22 +693,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( | |
| // errWrite, | ||
| // ) | ||
| cancel(errMsg.Error) | ||
| return completedOutput, errWrite | ||
| return completedOutput, terminalStatusCode, errWrite | ||
| } | ||
| } | ||
| if errMsg != nil { | ||
| cancel(errMsg.Error) | ||
| } else { | ||
| cancel(nil) | ||
| } | ||
| return completedOutput, nil | ||
| return completedOutput, terminalStatusCode, nil | ||
| case chunk, ok := <-data: | ||
| if !ok { | ||
| if !completed { | ||
| errMsg := &interfaces.ErrorMessage{ | ||
| StatusCode: http.StatusRequestTimeout, | ||
| Error: fmt.Errorf("stream closed before response.completed"), | ||
| } | ||
| terminalStatusCode = errMsg.StatusCode | ||
| h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) | ||
| markAPIResponseTimestamp(c) | ||
| errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) | ||
|
|
@@ -679,13 +729,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( | |
| errWrite, | ||
| ) | ||
| cancel(errMsg.Error) | ||
| return completedOutput, errWrite | ||
| return completedOutput, terminalStatusCode, errWrite | ||
| } | ||
| cancel(errMsg.Error) | ||
| return completedOutput, nil | ||
| return completedOutput, terminalStatusCode, nil | ||
| } | ||
| cancel(nil) | ||
| return completedOutput, nil | ||
| return completedOutput, terminalStatusCode, nil | ||
| } | ||
|
|
||
| payloads := websocketJSONPayloadsFromChunk(chunk) | ||
|
|
@@ -712,13 +762,22 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( | |
| errWrite, | ||
| ) | ||
| cancel(errWrite) | ||
| return completedOutput, errWrite | ||
| return completedOutput, terminalStatusCode, errWrite | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| func shouldResetResponsesWebsocketAuthPin(statusCode int) bool { | ||
| switch statusCode { | ||
| case http.StatusTooManyRequests, http.StatusForbidden, http.StatusPaymentRequired: | ||
| return true | ||
| default: | ||
| return false | ||
| } | ||
| } | ||
|
|
||
| func responseCompletedOutputFromPayload(payload []byte) []byte { | ||
| output := gjson.GetBytes(payload, "response.output") | ||
| if output.Exists() && output.IsArray() { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling
invalidateUpstreamConnduring auth changes closes the old socket immediately, but the oldreadUpstreamLoopcan still wake later and run its error path, which unconditionallyclearActive/closes whatever channel is currently active for the session. If a new request has already installed its ownactiveCh, that channel gets closed andreadCodexWebsocketMessagereturnssession read channel closed, aborting a healthy post-switch request. This race is specific to in-session auth switching and can break the new quota-recovery flow intermittently.Useful? React with 👍 / 👎.