Skip to content
9 changes: 9 additions & 0 deletions internal/runtime/executor/codex_websockets_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,12 +1067,21 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
if sess == nil {
return e.dialCodexWebsocket(ctx, auth, wsURL, headers)
}
authID = strings.TrimSpace(authID)

sess.connMu.Lock()
conn := sess.conn
readerConn := sess.readerConn
currentAuthID := strings.TrimSpace(sess.authID)
sess.connMu.Unlock()
if conn != nil && currentAuthID != authID {
// 账号切换时先断开旧连接避免继续复用旧账号
e.invalidateUpstreamConn(sess, conn, "auth_switched", nil)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid tearing down active reader channel on auth switch

Calling invalidateUpstreamConn during auth changes closes the old socket immediately, but the old readUpstreamLoop can still wake later and run its error path, which unconditionally clearActive/closes whatever channel is currently active for the session. If a new request has already installed its own activeCh, that channel gets closed and readCodexWebsocketMessage returns session read channel closed, aborting a healthy post-switch request. This race is specific to in-session auth switching and can break the new quota-recovery flow intermittently.

Useful? React with 👍 / 👎.

conn = nil
readerConn = nil
}
if conn != nil {
// 账号未变化时复用连接减少不必要重连
if readerConn != conn {
sess.connMu.Lock()
sess.readerConn = conn
Expand Down
91 changes: 91 additions & 0 deletions internal/runtime/executor/codex_websockets_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ import (
"context"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
Expand Down Expand Up @@ -201,3 +205,90 @@ func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
t.Fatal("expected websocket proxy function to be nil for direct mode")
}
}

func TestEnsureUpstreamConnReconnectsWhenAuthChanges(t *testing.T) {
var (
mu sync.Mutex
authorizations []string
)
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
mu.Lock()
authorizations = append(authorizations, strings.TrimSpace(r.Header.Get("Authorization")))
mu.Unlock()

go func() {
defer func() {
_ = conn.Close()
}()
for {
if _, _, errRead := conn.ReadMessage(); errRead != nil {
return
}
}
}()
}))
defer server.Close()

wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
executor := NewCodexWebsocketsExecutor(&config.Config{})
sess := executor.getOrCreateSession("test-session")
if sess == nil {
t.Fatal("expected session to be created")
}

auth1 := &cliproxyauth.Auth{ID: "auth-1"}
headers1 := http.Header{}
headers1.Set("Authorization", "Bearer token-1")
conn1, _, errDial1 := executor.ensureUpstreamConn(context.Background(), auth1, sess, auth1.ID, wsURL, headers1)
if errDial1 != nil {
t.Fatalf("first ensureUpstreamConn failed: %v", errDial1)
}
if conn1 == nil {
t.Fatal("first ensureUpstreamConn returned nil connection")
}

auth2 := &cliproxyauth.Auth{ID: "auth-2"}
headers2 := http.Header{}
headers2.Set("Authorization", "Bearer token-2")
conn2, _, errDial2 := executor.ensureUpstreamConn(context.Background(), auth2, sess, auth2.ID, wsURL, headers2)
if errDial2 != nil {
t.Fatalf("second ensureUpstreamConn failed: %v", errDial2)
}
if conn2 == nil {
t.Fatal("second ensureUpstreamConn returned nil connection")
}
if conn1 == conn2 {
t.Fatal("expected auth change to force upstream reconnect")
}

deadline := time.Now().Add(2 * time.Second)
for {
mu.Lock()
count := len(authorizations)
mu.Unlock()
if count >= 2 || time.Now().After(deadline) {
break
}
time.Sleep(10 * time.Millisecond)
}

mu.Lock()
got := append([]string(nil), authorizations...)
mu.Unlock()
if len(got) < 2 {
t.Fatalf("handshake count = %d, want at least 2", len(got))
}
if got[0] != "Bearer token-1" {
t.Fatalf("first Authorization = %q, want %q", got[0], "Bearer token-1")
}
if got[1] != "Bearer token-2" {
t.Fatalf("second Authorization = %q, want %q", got[1], "Bearer token-2")
}

executor.closeExecutionSession(sess, "test_done")
}
38 changes: 29 additions & 9 deletions sdk/api/handlers/openai/openai_responses_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,21 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")

completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
completedOutput, terminalStatus, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
if shouldResetResponsesWebsocketAuthPin(terminalStatus) {
// 限额错误后解除 pin 让后续请求重新选可用账号
pinnedAuthID = ""
if h != nil && h.AuthManager != nil {
Comment on lines +225 to +231

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Reset incremental state when unpinning after quota errors

After a quota-class terminal error this block clears pinnedAuthID and closes the execution session, but it leaves request state unchanged; on the next turn, allowIncrementalInputWithPreviousResponseID can still be true (model-level), so normalizeResponseSubsequentRequest may forward a client previous_response_id that belongs to the old account to the newly selected account. In websocket v2 flows that send previous_response_id, this causes the post-switch request to fail with upstream 4xx instead of recovering, so the auth-rotation fix is incomplete unless incremental mode is disabled/reset for at least the next turn.

Useful? React with 👍 / 👎.

// 主动关闭旧上游会话避免继续复用旧账号连接
h.AuthManager.CloseExecutionSession(passthroughSessionID)
}
}
lastResponseOutput = completedOutput
}
}
Expand Down Expand Up @@ -610,21 +618,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder,
sessionID string,
) ([]byte, error) {
) ([]byte, int, error) {
completed := false
completedOutput := []byte("[]")
terminalStatusCode := 0

for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
return completedOutput, terminalStatusCode, c.Request.Context().Err()
case errMsg, ok := <-errs:
if !ok {
errs = nil
continue
}
if errMsg != nil {
terminalStatusCode = errMsg.StatusCode
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
Expand All @@ -644,22 +654,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, terminalStatusCode, errWrite
}
}
if errMsg != nil {
cancel(errMsg.Error)
} else {
cancel(nil)
}
return completedOutput, nil
return completedOutput, terminalStatusCode, nil
case chunk, ok := <-data:
if !ok {
if !completed {
errMsg := &interfaces.ErrorMessage{
StatusCode: http.StatusRequestTimeout,
Error: fmt.Errorf("stream closed before response.completed"),
}
terminalStatusCode = errMsg.StatusCode
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
Expand All @@ -679,13 +690,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, terminalStatusCode, errWrite
}
cancel(errMsg.Error)
return completedOutput, nil
return completedOutput, terminalStatusCode, nil
}
cancel(nil)
return completedOutput, nil
return completedOutput, terminalStatusCode, nil
}

payloads := websocketJSONPayloadsFromChunk(chunk)
Expand All @@ -712,13 +723,22 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
return completedOutput, terminalStatusCode, errWrite
}
}
}
}
}

func shouldResetResponsesWebsocketAuthPin(statusCode int) bool {
switch statusCode {
case http.StatusTooManyRequests, http.StatusForbidden, http.StatusPaymentRequired:
return true
default:
return false
}
}

func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {
Expand Down
Loading
Loading