-
-
Notifications
You must be signed in to change notification settings - Fork 3.1k
fix: 修复 Codex WebSocket 会话在额度耗尽后不切换账号 #2256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
cc5cb31
32e8d1d
b9b9f77
ae8557e
42acfa7
e9a98a3
51ea8f1
7e7cc1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -117,6 +117,16 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { | |
| s.activeMu.Unlock() | ||
| } | ||
|
|
||
| func (s *codexWebsocketSession) isCurrentConn(conn *websocket.Conn) bool { | ||
| if s == nil || conn == nil { | ||
| return false | ||
| } | ||
| s.connMu.Lock() | ||
| current := s.conn | ||
| s.connMu.Unlock() | ||
| return current == conn | ||
| } | ||
|
|
||
| func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { | ||
| if s == nil { | ||
| return fmt.Errorf("codex websockets executor: session is nil") | ||
|
|
@@ -1067,12 +1077,21 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth * | |
| if sess == nil { | ||
| return e.dialCodexWebsocket(ctx, auth, wsURL, headers) | ||
| } | ||
| authID = strings.TrimSpace(authID) | ||
|
|
||
| sess.connMu.Lock() | ||
| conn := sess.conn | ||
| readerConn := sess.readerConn | ||
| currentAuthID := strings.TrimSpace(sess.authID) | ||
| sess.connMu.Unlock() | ||
| if conn != nil && currentAuthID != authID { | ||
| // 账号切换时先断开旧连接避免继续复用旧账号 | ||
| e.invalidateUpstreamConn(sess, conn, "auth_switched", nil) | ||
| conn = nil | ||
| readerConn = nil | ||
| } | ||
| if conn != nil { | ||
| // 账号未变化时复用连接减少不必要重连 | ||
| if readerConn != conn { | ||
| sess.connMu.Lock() | ||
| sess.readerConn = conn | ||
|
|
@@ -1114,9 +1133,17 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, | |
| return | ||
| } | ||
| for { | ||
| if !sess.isCurrentConn(conn) { | ||
| // 旧连接读循环直接退出避免误伤新请求通道 | ||
| return | ||
| } | ||
| _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) | ||
| msgType, payload, errRead := conn.ReadMessage() | ||
| if errRead != nil { | ||
| if !sess.isCurrentConn(conn) { | ||
| // 旧连接读错时不触碰当前活跃通道 | ||
| return | ||
| } | ||
|
Comment on lines
+1222
to
+1225
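There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Do not return immediately on […] — the inline code reference and the remainder of this review comment were lost in extraction; the comment is attached to the `!sess.isCurrentConn(conn)` early return in the read-error path (lines +1222 to +1225).
Useful? React with 👍 / 👎. |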
||
| sess.activeMu.Lock() | ||
| ch := sess.activeCh | ||
| done := sess.activeDone | ||
|
|
@@ -1137,6 +1164,10 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, | |
| if msgType != websocket.TextMessage { | ||
| if msgType == websocket.BinaryMessage { | ||
| errBinary := fmt.Errorf("codex websockets executor: unexpected binary message") | ||
| if !sess.isCurrentConn(conn) { | ||
| // 旧连接二进制异常时不触碰当前活跃通道 | ||
| return | ||
| } | ||
| sess.activeMu.Lock() | ||
| ch := sess.activeCh | ||
| done := sess.activeDone | ||
|
|
@@ -1155,6 +1186,10 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, | |
| } | ||
| continue | ||
| } | ||
| if !sess.isCurrentConn(conn) { | ||
| // 旧连接消息不再分发给新连接请求 | ||
| return | ||
| } | ||
|
|
||
| sess.activeMu.Lock() | ||
| ch := sess.activeCh | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,6 +81,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { | |
| var lastRequest []byte | ||
| lastResponseOutput := []byte("[]") | ||
| pinnedAuthID := "" | ||
| forceDisableIncrementalAfterAuthReset := false | ||
|
|
||
| for { | ||
| msgType, payload, errReadMessage := conn.ReadMessage() | ||
|
|
@@ -107,16 +108,18 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { | |
| appendWebsocketEvent(&wsBodyLog, "request", payload) | ||
|
|
||
| allowIncrementalInputWithPreviousResponseID := false | ||
| if pinnedAuthID != "" && h != nil && h.AuthManager != nil { | ||
| if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { | ||
| allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) | ||
| } | ||
| } else { | ||
| requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) | ||
| if requestModelName == "" { | ||
| requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) | ||
| if !forceDisableIncrementalAfterAuthReset { | ||
| if pinnedAuthID != "" && h != nil && h.AuthManager != nil { | ||
| if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { | ||
| allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) | ||
| } | ||
| } else { | ||
| requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) | ||
| if requestModelName == "" { | ||
| requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) | ||
| } | ||
| allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) | ||
| } | ||
| allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) | ||
| } | ||
|
|
||
| var requestJSON []byte | ||
|
|
@@ -192,13 +195,26 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { | |
| } | ||
| dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") | ||
|
|
||
| completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) | ||
| completedOutput, terminalStatus, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) | ||
| if errForward != nil { | ||
| wsTerminateErr = errForward | ||
| appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) | ||
| log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) | ||
| return | ||
| } | ||
| if shouldResetResponsesWebsocketAuthPin(terminalStatus) { | ||
| // 限额错误后解除 pin 让后续请求重新选可用账号 | ||
| pinnedAuthID = "" | ||
| // 切号恢复阶段先禁用增量模式避免沿用旧账号 response id | ||
| forceDisableIncrementalAfterAuthReset = true | ||
|
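There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Setting `forceDisableIncrementalAfterAuthReset` […] — the remainder of this review comment was lost in extraction; it is attached to the `forceDisableIncrementalAfterAuthReset = true` assignment directly above.
Useful? React with 👍 / 👎. |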
||
| if h != nil && h.AuthManager != nil { | ||
|
Comment on lines
+225
to
+231
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
After a quota-class terminal error this block clears `pinnedAuthID` […] — the remainder of this review comment was lost in extraction (comment on lines +225 to +231).
Useful? React with 👍 / 👎. |
||
| // 主动关闭旧上游会话避免继续复用旧账号连接 | ||
| h.AuthManager.CloseExecutionSession(passthroughSessionID) | ||
| } | ||
| } else if forceDisableIncrementalAfterAuthReset && terminalStatus == 0 { | ||
| // 新账号完成一轮后恢复增量模式 | ||
| forceDisableIncrementalAfterAuthReset = false | ||
| } | ||
| lastResponseOutput = completedOutput | ||
| } | ||
| } | ||
|
|
@@ -610,21 +626,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( | |
| errs <-chan *interfaces.ErrorMessage, | ||
| wsBodyLog *strings.Builder, | ||
| sessionID string, | ||
| ) ([]byte, error) { | ||
| ) ([]byte, int, error) { | ||
| completed := false | ||
| completedOutput := []byte("[]") | ||
| terminalStatusCode := 0 | ||
|
|
||
| for { | ||
| select { | ||
| case <-c.Request.Context().Done(): | ||
| cancel(c.Request.Context().Err()) | ||
| return completedOutput, c.Request.Context().Err() | ||
| return completedOutput, terminalStatusCode, c.Request.Context().Err() | ||
| case errMsg, ok := <-errs: | ||
| if !ok { | ||
| errs = nil | ||
| continue | ||
| } | ||
| if errMsg != nil { | ||
| terminalStatusCode = errMsg.StatusCode | ||
| h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) | ||
| markAPIResponseTimestamp(c) | ||
| errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) | ||
|
|
@@ -644,22 +662,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( | |
| // errWrite, | ||
| // ) | ||
| cancel(errMsg.Error) | ||
| return completedOutput, errWrite | ||
| return completedOutput, terminalStatusCode, errWrite | ||
| } | ||
| } | ||
| if errMsg != nil { | ||
| cancel(errMsg.Error) | ||
| } else { | ||
| cancel(nil) | ||
| } | ||
| return completedOutput, nil | ||
| return completedOutput, terminalStatusCode, nil | ||
| case chunk, ok := <-data: | ||
| if !ok { | ||
| if !completed { | ||
| errMsg := &interfaces.ErrorMessage{ | ||
| StatusCode: http.StatusRequestTimeout, | ||
| Error: fmt.Errorf("stream closed before response.completed"), | ||
| } | ||
| terminalStatusCode = errMsg.StatusCode | ||
| h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) | ||
| markAPIResponseTimestamp(c) | ||
| errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) | ||
|
|
@@ -679,13 +698,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( | |
| errWrite, | ||
| ) | ||
| cancel(errMsg.Error) | ||
| return completedOutput, errWrite | ||
| return completedOutput, terminalStatusCode, errWrite | ||
| } | ||
| cancel(errMsg.Error) | ||
| return completedOutput, nil | ||
| return completedOutput, terminalStatusCode, nil | ||
| } | ||
| cancel(nil) | ||
| return completedOutput, nil | ||
| return completedOutput, terminalStatusCode, nil | ||
| } | ||
|
|
||
| payloads := websocketJSONPayloadsFromChunk(chunk) | ||
|
|
@@ -712,13 +731,22 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( | |
| errWrite, | ||
| ) | ||
| cancel(errWrite) | ||
| return completedOutput, errWrite | ||
| return completedOutput, terminalStatusCode, errWrite | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
// shouldResetResponsesWebsocketAuthPin reports whether statusCode is one of
// the terminal statuses that trigger an auth-pin reset: 429 Too Many Requests,
// 403 Forbidden, or 402 Payment Required. Every other value, including 0,
// yields false.
func shouldResetResponsesWebsocketAuthPin(statusCode int) bool {
	return statusCode == http.StatusTooManyRequests ||
		statusCode == http.StatusForbidden ||
		statusCode == http.StatusPaymentRequired
}
|
|
||
| func responseCompletedOutputFromPayload(payload []byte) []byte { | ||
| output := gjson.GetBytes(payload, "response.output") | ||
| if output.Exists() && output.IsArray() { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling `invalidateUpstreamConn` during auth changes closes the old socket immediately, but the old `readUpstreamLoop` can still wake later and run its error path, which unconditionally `clearActive`/closes whatever channel is currently active for the session. If a new request has already installed its own `activeCh`, that channel gets closed and `readCodexWebsocketMessage` returns `session read channel closed`, aborting a healthy post-switch request. This race is specific to in-session auth switching and can break the new quota-recovery flow intermittently.
Useful? React with 👍 / 👎.