From cc5cb31186e867be6e5a707244488e13dba8aa3a Mon Sep 17 00:00:00 2001 From: AYANGarch Date: Fri, 20 Mar 2026 13:28:02 +0000 Subject: [PATCH 1/8] =?UTF-8?q?fix(websocket):=20=E9=A2=9D=E5=BA=A6?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E5=90=8E=E5=85=81=E8=AE=B8=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E5=86=85=E5=88=87=E6=8D=A2=E5=8F=AF=E7=94=A8=E8=B4=A6=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../executor/codex_websockets_executor.go | 9 ++ .../codex_websockets_executor_test.go | 91 +++++++++++ .../openai/openai_responses_websocket.go | 38 +++-- .../openai/openai_responses_websocket_test.go | 152 +++++++++++++++++- 4 files changed, 280 insertions(+), 10 deletions(-) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 571a23a1eb..5fe6db759a 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -1067,12 +1067,21 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth * if sess == nil { return e.dialCodexWebsocket(ctx, auth, wsURL, headers) } + authID = strings.TrimSpace(authID) sess.connMu.Lock() conn := sess.conn readerConn := sess.readerConn + currentAuthID := strings.TrimSpace(sess.authID) sess.connMu.Unlock() + if conn != nil && currentAuthID != authID { + // 账号切换时先断开旧连接避免继续复用旧账号 + e.invalidateUpstreamConn(sess, conn, "auth_switched", nil) + conn = nil + readerConn = nil + } if conn != nil { + // 账号未变化时复用连接减少不必要重连 if readerConn != conn { sess.connMu.Lock() sess.readerConn = conn diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index 755ac56ac4..855239edec 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -4,9 +4,13 @@ import ( "context" "net/http" "net/http/httptest" + "strings" + "sync" "testing" + "time" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" @@ -201,3 +205,90 @@ func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) { t.Fatal("expected websocket proxy function to be nil for direct mode") } } + +func TestEnsureUpstreamConnReconnectsWhenAuthChanges(t *testing.T) { + var ( + mu sync.Mutex + authorizations []string + ) + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + mu.Lock() + authorizations = append(authorizations, strings.TrimSpace(r.Header.Get("Authorization"))) + mu.Unlock() + + go func() { + defer func() { + _ = conn.Close() + }() + for { + if _, _, errRead := conn.ReadMessage(); errRead != nil { + return + } + } + }() + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + executor := NewCodexWebsocketsExecutor(&config.Config{}) + sess := executor.getOrCreateSession("test-session") + if sess == nil { + t.Fatal("expected session to be created") + } + + auth1 := &cliproxyauth.Auth{ID: "auth-1"} + headers1 := http.Header{} + headers1.Set("Authorization", "Bearer token-1") + conn1, _, errDial1 := executor.ensureUpstreamConn(context.Background(), auth1, sess, auth1.ID, wsURL, headers1) + if errDial1 != nil { + t.Fatalf("first ensureUpstreamConn failed: %v", errDial1) + } + if conn1 == nil { + t.Fatal("first ensureUpstreamConn returned nil connection") + } + + auth2 := &cliproxyauth.Auth{ID: "auth-2"} + headers2 := http.Header{} + headers2.Set("Authorization", "Bearer token-2") + conn2, _, errDial2 := executor.ensureUpstreamConn(context.Background(), auth2, sess, auth2.ID, wsURL, headers2) + if errDial2 != nil { + t.Fatalf("second ensureUpstreamConn failed: %v", errDial2) + } + if conn2 == nil { + t.Fatal("second ensureUpstreamConn returned nil connection") + } + if conn1 == conn2 { + t.Fatal("expected auth change to force upstream reconnect") + } + + deadline := time.Now().Add(2 * time.Second) + for { + mu.Lock() + count := len(authorizations) + mu.Unlock() + if count >= 2 || time.Now().After(deadline) { + break + } + time.Sleep(10 * time.Millisecond) + } + + mu.Lock() + got := append([]string(nil), authorizations...) + mu.Unlock() + if len(got) < 2 { + t.Fatalf("handshake count = %d, want at least 2", len(got)) + } + if got[0] != "Bearer token-1" { + t.Fatalf("first Authorization = %q, want %q", got[0], "Bearer token-1") + } + if got[1] != "Bearer token-2" { + t.Fatalf("second Authorization = %q, want %q", got[1], "Bearer token-2") + } + + executor.closeExecutionSession(sess, "test_done") +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 5c68f40e15..51b01587eb 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -192,13 +192,21 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") - completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) + completedOutput, terminalStatus, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) if errForward != nil { wsTerminateErr = errForward appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) return } + if shouldResetResponsesWebsocketAuthPin(terminalStatus) { + // 限额错误后解除 pin 让后续请求重新选可用账号 + pinnedAuthID = "" + if h != nil && h.AuthManager != nil { + // 主动关闭旧上游会话避免继续复用旧账号连接 + h.AuthManager.CloseExecutionSession(passthroughSessionID) + } + } lastResponseOutput = completedOutput } } @@ -610,21 +618,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errs <-chan *interfaces.ErrorMessage, wsBodyLog *strings.Builder, sessionID string, -) ([]byte, error) { +) ([]byte, int, error) { completed := false completedOutput := []byte("[]") + terminalStatusCode := 0 for { select { case <-c.Request.Context().Done(): cancel(c.Request.Context().Err()) - return completedOutput, c.Request.Context().Err() + return completedOutput, terminalStatusCode, c.Request.Context().Err() case errMsg, ok := <-errs: if !ok { errs = nil continue } if errMsg != nil { + terminalStatusCode = errMsg.StatusCode h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) @@ -644,7 +654,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( // errWrite, // ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } } if errMsg != nil { @@ -652,7 +662,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( } else { cancel(nil) } - return completedOutput, nil + return completedOutput, terminalStatusCode, nil case chunk, ok := <-data: if !ok { if !completed { @@ -660,6 +670,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( StatusCode: http.StatusRequestTimeout, Error: fmt.Errorf("stream closed before response.completed"), } + terminalStatusCode = errMsg.StatusCode h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) @@ -679,13 +690,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } cancel(errMsg.Error) - return completedOutput, nil + return completedOutput, terminalStatusCode, nil } cancel(nil) - return completedOutput, nil + return completedOutput, terminalStatusCode, nil } payloads := websocketJSONPayloadsFromChunk(chunk) @@ -712,13 +723,22 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errWrite) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } } } } } +func shouldResetResponsesWebsocketAuthPin(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, http.StatusForbidden, http.StatusPaymentRequired: + return true + default: + return false + } +} + func responseCompletedOutputFromPayload(payload []byte) []byte { output := gjson.GetBytes(payload, "response.output") if output.Exists() && output.IsArray() { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index b3a32c5c9d..a726d9ee04 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -62,6 +62,27 @@ type websocketAuthCaptureExecutor struct { authIDs []string } +type websocketStatusError struct { + code int + msg string +} + +func (e websocketStatusError) Error() string { + if strings.TrimSpace(e.msg) != "" { + return e.msg + } + return fmt.Sprintf("status %d", e.code) +} + +func (e websocketStatusError) StatusCode() int { + return e.code +} + +type websocketQuotaSwitchExecutor struct { + mu sync.Mutex + authIDs []string +} + func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -99,6 +120,47 @@ func (e *websocketAuthCaptureExecutor) AuthIDs() []string { return append([]string(nil), e.authIDs...) } +func (e *websocketQuotaSwitchExecutor) Identifier() string { return "test-provider" } + +func (e *websocketQuotaSwitchExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketQuotaSwitchExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + e.mu.Lock() + if auth != nil { + e.authIDs = append(e.authIDs, auth.ID) + } + e.mu.Unlock() + + if auth != nil && auth.ID == "auth-1" { + return nil, websocketStatusError{code: http.StatusTooManyRequests, msg: "quota exhausted"} + } + + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketQuotaSwitchExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketQuotaSwitchExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketQuotaSwitchExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + +func (e *websocketQuotaSwitchExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + return append([]string(nil), e.authIDs...) +} + func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -417,7 +479,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { close(errCh) var bodyLog strings.Builder - completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + completedOutput, statusCode, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( ctx, conn, func(...interface{}) {}, @@ -430,6 +492,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { serverErrCh <- err return } + if statusCode != 0 { + serverErrCh <- fmt.Errorf("status code = %d, want 0", statusCode) + return + } if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" { serverErrCh <- errors.New("completed output not captured") return @@ -662,3 +728,87 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) { t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got) } } + +func TestResponsesWebsocketClearsPinAndSwitchesAuthAfterQuotaError(t *testing.T) { + gin.SetMode(gin.TestMode) + + selector := &orderedWebsocketSelector{order: []string{"auth-1", "auth-2"}} + executor := &websocketQuotaSwitchExecutor{} + manager := coreauth.NewManager(nil, selector, nil) + manager.SetRetryConfig(0, 0, 1) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth-1", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("Register auth-1: %v", err) + } + auth2 := &coreauth.Auth{ + ID: "auth-2", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("Register auth-2: %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + registry.GetGlobalRegistry().UnregisterClient(auth2.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + if errClose := conn.Close(); errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`)); errWrite != nil { + t.Fatalf("write first websocket message: %v", errWrite) + } + _, firstPayload, errReadFirst := conn.ReadMessage() + if errReadFirst != nil { + t.Fatalf("read first websocket message: %v", errReadFirst) + } + if got := gjson.GetBytes(firstPayload, "type").String(); got != wsEventTypeError { + t.Fatalf("first payload type = %s, want %s", got, wsEventTypeError) + } + if got := gjson.GetBytes(firstPayload, "error.code").String(); got != "rate_limit_exceeded" { + t.Fatalf("first payload code = %s, want rate_limit_exceeded", got) + } + + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-2"}]}`)); errWrite != nil { + t.Fatalf("write second websocket message: %v", errWrite) + } + _, secondPayload, errReadSecond := conn.ReadMessage() + if errReadSecond != nil { + t.Fatalf("read second websocket message: %v", errReadSecond) + } + if got := gjson.GetBytes(secondPayload, "type").String(); got != wsEventTypeCompleted { + t.Fatalf("second payload type = %s, want %s", got, wsEventTypeCompleted) + } + + if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-1" || got[1] != "auth-2" { + t.Fatalf("selected auth IDs = %v, want [auth-1 auth-2]", got) + } +} From 32e8d1dff7dff3801369db7772fda491d07c7eca Mon Sep 17 00:00:00 2001 From: AYANGarch Date: Fri, 20 Mar 2026 13:32:46 +0000 Subject: [PATCH 2/8] =?UTF-8?q?chore(test):=20=E7=A7=BB=E9=99=A4=E6=9C=AC?= =?UTF-8?q?=E6=AC=A1=E6=96=B0=E5=A2=9E=E5=9B=9E=E5=BD=92=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../codex_websockets_executor_test.go | 91 ----------- .../openai/openai_responses_websocket_test.go | 146 ------------------ 2 files changed, 237 deletions(-) diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index 855239edec..755ac56ac4 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -4,13 +4,9 @@ import ( "context" "net/http" "net/http/httptest" - "strings" - "sync" "testing" - "time" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" @@ -205,90 +201,3 @@ func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) { t.Fatal("expected websocket proxy function to be nil for direct mode") } } - -func TestEnsureUpstreamConnReconnectsWhenAuthChanges(t *testing.T) { - var ( - mu sync.Mutex - authorizations []string - ) - upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - mu.Lock() - authorizations = append(authorizations, strings.TrimSpace(r.Header.Get("Authorization"))) - mu.Unlock() - - go func() { - defer func() { - _ = conn.Close() - }() - for { - if _, _, errRead := conn.ReadMessage(); errRead != nil { - return - } - } - }() - })) - defer server.Close() - - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - executor := NewCodexWebsocketsExecutor(&config.Config{}) - sess := executor.getOrCreateSession("test-session") - if sess == nil { - t.Fatal("expected session to be created") - } - - auth1 := &cliproxyauth.Auth{ID: "auth-1"} - headers1 := http.Header{} - headers1.Set("Authorization", "Bearer token-1") - conn1, _, errDial1 := executor.ensureUpstreamConn(context.Background(), auth1, sess, auth1.ID, wsURL, headers1) - if errDial1 != nil { - t.Fatalf("first ensureUpstreamConn failed: %v", errDial1) - } - if conn1 == nil { - t.Fatal("first ensureUpstreamConn returned nil connection") - } - - auth2 := &cliproxyauth.Auth{ID: "auth-2"} - headers2 := http.Header{} - headers2.Set("Authorization", "Bearer token-2") - conn2, _, errDial2 := executor.ensureUpstreamConn(context.Background(), auth2, sess, auth2.ID, wsURL, headers2) - if errDial2 != nil { - t.Fatalf("second ensureUpstreamConn failed: %v", errDial2) - } - if conn2 == nil { - t.Fatal("second ensureUpstreamConn returned nil connection") - } - if conn1 == conn2 { - t.Fatal("expected auth change to force upstream reconnect") - } - - deadline := time.Now().Add(2 * time.Second) - for { - mu.Lock() - count := len(authorizations) - mu.Unlock() - if count >= 2 || time.Now().After(deadline) { - break - } - time.Sleep(10 * time.Millisecond) - } - - mu.Lock() - got := append([]string(nil), authorizations...) - mu.Unlock() - if len(got) < 2 { - t.Fatalf("handshake count = %d, want at least 2", len(got)) - } - if got[0] != "Bearer token-1" { - t.Fatalf("first Authorization = %q, want %q", got[0], "Bearer token-1") - } - if got[1] != "Bearer token-2" { - t.Fatalf("second Authorization = %q, want %q", got[1], "Bearer token-2") - } - - executor.closeExecutionSession(sess, "test_done") -} diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index a726d9ee04..ef084eeb5a 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -62,27 +62,6 @@ type websocketAuthCaptureExecutor struct { authIDs []string } -type websocketStatusError struct { - code int - msg string -} - -func (e websocketStatusError) Error() string { - if strings.TrimSpace(e.msg) != "" { - return e.msg - } - return fmt.Sprintf("status %d", e.code) -} - -func (e websocketStatusError) StatusCode() int { - return e.code -} - -type websocketQuotaSwitchExecutor struct { - mu sync.Mutex - authIDs []string -} - func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -120,47 +99,6 @@ func (e *websocketAuthCaptureExecutor) AuthIDs() []string { return append([]string(nil), e.authIDs...) } -func (e *websocketQuotaSwitchExecutor) Identifier() string { return "test-provider" } - -func (e *websocketQuotaSwitchExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { - return coreexecutor.Response{}, errors.New("not implemented") -} - -func (e *websocketQuotaSwitchExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { - e.mu.Lock() - if auth != nil { - e.authIDs = append(e.authIDs, auth.ID) - } - e.mu.Unlock() - - if auth != nil && auth.ID == "auth-1" { - return nil, websocketStatusError{code: http.StatusTooManyRequests, msg: "quota exhausted"} - } - - chunks := make(chan coreexecutor.StreamChunk, 1) - chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)} - close(chunks) - return &coreexecutor.StreamResult{Chunks: chunks}, nil -} - -func (e *websocketQuotaSwitchExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { - return auth, nil -} - -func (e *websocketQuotaSwitchExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { - return coreexecutor.Response{}, errors.New("not implemented") -} - -func (e *websocketQuotaSwitchExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { - return nil, errors.New("not implemented") -} - -func (e *websocketQuotaSwitchExecutor) AuthIDs() []string { - e.mu.Lock() - defer e.mu.Unlock() - return append([]string(nil), e.authIDs...) -} - func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -728,87 +666,3 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) { t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got) } } - -func TestResponsesWebsocketClearsPinAndSwitchesAuthAfterQuotaError(t *testing.T) { - gin.SetMode(gin.TestMode) - - selector := &orderedWebsocketSelector{order: []string{"auth-1", "auth-2"}} - executor := &websocketQuotaSwitchExecutor{} - manager := coreauth.NewManager(nil, selector, nil) - manager.SetRetryConfig(0, 0, 1) - manager.RegisterExecutor(executor) - - auth1 := &coreauth.Auth{ - ID: "auth-1", - Provider: executor.Identifier(), - Status: coreauth.StatusActive, - Attributes: map[string]string{"websockets": "true"}, - } - if _, err := manager.Register(context.Background(), auth1); err != nil { - t.Fatalf("Register auth-1: %v", err) - } - auth2 := &coreauth.Auth{ - ID: "auth-2", - Provider: executor.Identifier(), - Status: coreauth.StatusActive, - Attributes: map[string]string{"websockets": "true"}, - } - if _, err := manager.Register(context.Background(), auth2); err != nil { - t.Fatalf("Register auth-2: %v", err) - } - - registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) - registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) - t.Cleanup(func() { - registry.GetGlobalRegistry().UnregisterClient(auth1.ID) - registry.GetGlobalRegistry().UnregisterClient(auth2.ID) - }) - - base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) - h := NewOpenAIResponsesAPIHandler(base) - router := gin.New() - router.GET("/v1/responses/ws", h.ResponsesWebsocket) - - server := httptest.NewServer(router) - defer server.Close() - - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" - conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - if err != nil { - t.Fatalf("dial websocket: %v", err) - } - defer func() { - if errClose := conn.Close(); errClose != nil { - t.Fatalf("close websocket: %v", errClose) - } - }() - - if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`)); errWrite != nil { - t.Fatalf("write first websocket message: %v", errWrite) - } - _, firstPayload, errReadFirst := conn.ReadMessage() - if errReadFirst != nil { - t.Fatalf("read first websocket message: %v", errReadFirst) - } - if got := gjson.GetBytes(firstPayload, "type").String(); got != wsEventTypeError { - t.Fatalf("first payload type = %s, want %s", got, wsEventTypeError) - } - if got := gjson.GetBytes(firstPayload, "error.code").String(); got != "rate_limit_exceeded" { - t.Fatalf("first payload code = %s, want rate_limit_exceeded", got) - } - - if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-2"}]}`)); errWrite != nil { - t.Fatalf("write second websocket message: %v", errWrite) - } - _, secondPayload, errReadSecond := conn.ReadMessage() - if errReadSecond != nil { - t.Fatalf("read second websocket message: %v", errReadSecond) - } - if got := gjson.GetBytes(secondPayload, "type").String(); got != wsEventTypeCompleted { - t.Fatalf("second payload type = %s, want %s", got, wsEventTypeCompleted) - } - - if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-1" || got[1] != "auth-2" { - t.Fatalf("selected auth IDs = %v, want [auth-1 auth-2]", got) - } -} From b9b9f77ed8336f921b1d7992c560b4e50336cb21 Mon Sep 17 00:00:00 2001 From: AYANGarch Date: Fri, 20 Mar 2026 15:04:23 +0000 Subject: [PATCH 3/8] =?UTF-8?q?fix(websocket):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=88=87=E5=8F=B7=E5=90=8E=E7=9A=84=E5=A2=9E=E9=87=8F=E6=B1=A1?= =?UTF-8?q?=E6=9F=93=E4=B8=8E=E8=AF=BB=E9=80=9A=E9=81=93=E7=AB=9E=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../executor/codex_websockets_executor.go | 26 +++++++++++++++++++ .../openai/openai_responses_websocket.go | 26 ++++++++++++------- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 5fe6db759a..41aae255c3 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -117,6 +117,16 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { s.activeMu.Unlock() } +func (s *codexWebsocketSession) isCurrentConn(conn *websocket.Conn) bool { + if s == nil || conn == nil { + return false + } + s.connMu.Lock() + current := s.conn + s.connMu.Unlock() + return current == conn +} + func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { if s == nil { return fmt.Errorf("codex websockets executor: session is nil") @@ -1123,9 +1133,17 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, return } for { + if !sess.isCurrentConn(conn) { + // 旧连接读循环直接退出避免误伤新请求通道 + return + } _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) msgType, payload, errRead := conn.ReadMessage() if errRead != nil { + if !sess.isCurrentConn(conn) { + // 旧连接读错时不触碰当前活跃通道 + return + } sess.activeMu.Lock() ch := sess.activeCh done := sess.activeDone @@ -1146,6 +1164,10 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, if msgType != websocket.TextMessage { if msgType == websocket.BinaryMessage { errBinary := fmt.Errorf("codex websockets executor: unexpected binary message") + if !sess.isCurrentConn(conn) { + // 旧连接二进制异常时不触碰当前活跃通道 + return + } sess.activeMu.Lock() ch := sess.activeCh done := sess.activeDone @@ -1164,6 +1186,10 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, } continue } + if !sess.isCurrentConn(conn) { + // 旧连接消息不再分发给新连接请求 + return + } sess.activeMu.Lock() ch := sess.activeCh diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 51b01587eb..a95995bdac 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -81,6 +81,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { var lastRequest []byte lastResponseOutput := []byte("[]") pinnedAuthID := "" + forceDisableIncrementalAfterAuthReset := false for { msgType, payload, errReadMessage := conn.ReadMessage() @@ -107,16 +108,18 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { appendWebsocketEvent(&wsBodyLog, "request", payload) allowIncrementalInputWithPreviousResponseID := false - if pinnedAuthID != "" && h != nil && h.AuthManager != nil { - if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { - allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) - } - } else { - requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) - if requestModelName == "" { - requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + if !forceDisableIncrementalAfterAuthReset { + if pinnedAuthID != "" && h != nil && h.AuthManager != nil { + if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { + allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) + } + } else { + requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if requestModelName == "" { + requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + } + allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } - allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } var requestJSON []byte @@ -202,10 +205,15 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { if shouldResetResponsesWebsocketAuthPin(terminalStatus) { // 限额错误后解除 pin 让后续请求重新选可用账号 pinnedAuthID = "" + // 切号恢复阶段先禁用增量模式避免沿用旧账号 response id + forceDisableIncrementalAfterAuthReset = true if h != nil && h.AuthManager != nil { // 主动关闭旧上游会话避免继续复用旧账号连接 h.AuthManager.CloseExecutionSession(passthroughSessionID) } + } else if forceDisableIncrementalAfterAuthReset && terminalStatus == 0 { + // 新账号完成一轮后恢复增量模式 + forceDisableIncrementalAfterAuthReset = false } lastResponseOutput = completedOutput } From ae8557efd5b39ac2ec4e10b88bd2bb014110dcb9 Mon Sep 17 00:00:00 2001 From: AYANGarch Date: Fri, 20 Mar 2026 15:29:42 +0000 Subject: [PATCH 4/8] =?UTF-8?q?fix(websocket):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=88=87=E5=8F=B7=E6=81=A2=E5=A4=8D=E6=97=B6=E4=B8=8A=E4=B8=8B?= =?UTF-8?q?=E6=96=87=E5=BF=AB=E7=85=A7=E6=88=AA=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../openai/openai_responses_websocket.go | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index a95995bdac..7f4ea9c0a3 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -154,6 +154,28 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } + nextSessionRequestSnapshot := updatedLastRequest + if shouldBuildResponsesWebsocketFullSnapshot(requestJSON, allowIncrementalInputWithPreviousResponseID) { + _, shadowLastRequest, shadowErr := normalizeResponsesWebsocketRequestWithMode( + payload, + lastRequest, + lastResponseOutput, + false, + ) + if shadowErr != nil { + // 影子快照失败时保留旧快照避免污染会话状态 + nextSessionRequestSnapshot = lastRequest + log.Warnf( + "responses websocket: keep previous snapshot id=%s status=%d error=%v", + passthroughSessionID, + shadowErr.StatusCode, + shadowErr.Error, + ) + } else { + // 增量模式只发 delta 同时维护完整快照供切号恢复 + nextSessionRequestSnapshot = shadowLastRequest + } + } if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil { requestJSON = updated @@ -170,7 +192,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } - lastRequest = updatedLastRequest + // lastRequest 始终保存完整 transcript 快照 + lastRequest = nextSessionRequestSnapshot modelName := gjson.GetBytes(requestJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) @@ -504,6 +527,14 @@ func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest [] return generateResult.Exists() && !generateResult.Bool() } +func shouldBuildResponsesWebsocketFullSnapshot(normalizedRequestJSON []byte, allowIncrementalInputWithPreviousResponseID bool) bool { + if !allowIncrementalInputWithPreviousResponseID { + return false + } + prev := strings.TrimSpace(gjson.GetBytes(normalizedRequestJSON, "previous_response_id").String()) + return prev != "" +} + func writeResponsesWebsocketSyntheticPrewarm( c *gin.Context, conn *websocket.Conn, From 42acfa729d2eed661a2c5379ecb14475ae3bbe9f Mon Sep 17 00:00:00 2001 From: AYANGarch Date: Sat, 21 Mar 2026 01:53:07 +0000 Subject: [PATCH 5/8] =?UTF-8?q?fix(websocket):=20=E6=94=B6=E6=95=9B?= =?UTF-8?q?=E8=AF=BB=E5=BE=AA=E7=8E=AF=E9=80=9A=E9=81=93=E5=BD=92=E5=B1=9E?= =?UTF-8?q?=E7=AB=9E=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../executor/codex_websockets_executor.go | 78 ++++++++++++++----- 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 41aae255c3..5b23490012 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -127,6 +127,49 @@ func (s *codexWebsocketSession) isCurrentConn(conn *websocket.Conn) bool { return current == conn } +func (s *codexWebsocketSession) activeSnapshotForCurrentConn(conn *websocket.Conn) (chan codexWebsocketRead, <-chan struct{}, bool) { + if s == nil || conn == nil { + return nil, nil, false + } + s.connMu.Lock() + if s.conn != conn { + s.connMu.Unlock() + return nil, nil, false + } + s.activeMu.Lock() + ch := s.activeCh + done := s.activeDone + s.activeMu.Unlock() + s.connMu.Unlock() + return ch, done, true +} + +func (s *codexWebsocketSession) clearActiveForCurrentConn(conn *websocket.Conn, ch chan codexWebsocketRead) bool { + if s == nil || conn == nil || ch == nil { + return false + } + s.connMu.Lock() + if s.conn != conn { + s.connMu.Unlock() + return false + } + s.activeMu.Lock() + if s.activeCh != ch { + s.activeMu.Unlock() + s.connMu.Unlock() + return false + } + s.activeCh = nil + if s.activeCancel != nil { + s.activeCancel() + } + s.activeCancel = nil + s.activeDone = nil + s.activeMu.Unlock() + s.connMu.Unlock() + return true +} + func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { if s == nil { return fmt.Errorf("codex websockets executor: session is nil") @@ -1140,22 +1183,21 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) msgType, payload, errRead := conn.ReadMessage() if errRead != nil { - if !sess.isCurrentConn(conn) { + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { // 旧连接读错时不触碰当前活跃通道 return } - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() if ch != nil { select { case ch <- codexWebsocketRead{conn: conn, err: errRead}: case <-done: default: } - sess.clearActive(ch) - close(ch) + if sess.clearActiveForCurrentConn(conn, ch) { + close(ch) + } } e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead) return @@ -1164,37 +1206,33 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, if msgType != websocket.TextMessage { if msgType == websocket.BinaryMessage { errBinary := fmt.Errorf("codex websockets executor: unexpected binary message") - if !sess.isCurrentConn(conn) { + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { // 旧连接二进制异常时不触碰当前活跃通道 return } - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() if ch != nil { select { case ch <- codexWebsocketRead{conn: conn, err: errBinary}: case <-done: default: } - sess.clearActive(ch) - close(ch) + if sess.clearActiveForCurrentConn(conn, ch) { + close(ch) + } } e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) return } continue } - if !sess.isCurrentConn(conn) { + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { // 旧连接消息不再分发给新连接请求 return } - - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() if ch == nil { continue } From e9a98a38384ed2ed17ad2ab4b41c525c687c36ec Mon Sep 17 00:00:00 2001 From: AYANGarch Date: Sat, 21 Mar 2026 02:34:51 +0000 Subject: [PATCH 6/8] =?UTF-8?q?fix(websocket):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E4=BC=9A=E8=AF=9D=E5=85=B3=E9=97=AD=E6=8C=82=E8=B5=B7=E5=B9=B6?= =?UTF-8?q?=E8=A1=A5=E5=85=B3=E9=94=AE=E5=9B=9E=E5=BD=92=E5=8D=95=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../executor/codex_websockets_executor.go | 69 ++++++++++++--- .../codex_websockets_executor_test.go | 88 +++++++++++++++++++ .../openai/openai_responses_websocket_test.go | 28 ++++++ 3 files changed, 172 insertions(+), 13 deletions(-) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 5b23490012..709200ef72 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -82,10 +82,31 @@ type codexWebsocketRead struct { err error } +func trySendCodexWebsocketRead(ch chan codexWebsocketRead, done <-chan struct{}, ev codexWebsocketRead) { + if ch == nil { + return + } + defer func() { _ = recover() }() + select { + case ch <- ev: + case <-done: + default: + } +} + +func tryCloseCodexWebsocketRead(ch chan codexWebsocketRead) { + if ch == nil { + return + } + defer func() { _ = recover() }() + close(ch) +} + func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) { if s == nil { return } + // 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突 s.activeMu.Lock() if s.activeCancel != nil { s.activeCancel() @@ -105,6 +126,7 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { if s == nil { return } + // 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突 s.activeMu.Lock() if s.activeCh == ch { s.activeCh = nil @@ -131,6 +153,7 @@ func (s *codexWebsocketSession) activeSnapshotForCurrentConn(conn *websocket.Con if s == nil || conn == nil { return nil, nil, false } + // 锁顺序固定为 connMu -> activeMu s.connMu.Lock() if s.conn != conn { s.connMu.Unlock() @@ -148,6 +171,7 @@ func (s *codexWebsocketSession) clearActiveForCurrentConn(conn *websocket.Conn, if s == nil || conn == nil || ch == nil { return false } + // 锁顺序固定为 connMu -> activeMu s.connMu.Lock() if s.conn != conn { s.connMu.Unlock() @@ -170,6 +194,30 @@ func (s *codexWebsocketSession) clearActiveForCurrentConn(conn *websocket.Conn, return true } +func (s *codexWebsocketSession) failActiveForSessionClose(conn *websocket.Conn, err error) { + if s == nil { + return + } + if err == nil { + err = fmt.Errorf("codex websockets executor: execution session closed") + } + s.activeMu.Lock() + ch := s.activeCh + done := s.activeDone + if s.activeCancel != nil { + s.activeCancel() + } + s.activeCh = nil + s.activeCancel = nil + s.activeDone = nil + s.activeMu.Unlock() + if ch == nil { + return + } + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: err}) + tryCloseCodexWebsocketRead(ch) +} + func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { if s == nil { return fmt.Errorf("codex websockets executor: session is nil") @@ -1117,10 +1165,10 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb } func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { + authID = strings.TrimSpace(authID) if sess == nil { return e.dialCodexWebsocket(ctx, auth, wsURL, headers) } - authID = strings.TrimSpace(authID) sess.connMu.Lock() conn := sess.conn @@ -1190,13 +1238,9 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, return } if ch != nil { - select { - case ch <- codexWebsocketRead{conn: conn, err: errRead}: - case <-done: - default: - } + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: errRead}) if sess.clearActiveForCurrentConn(conn, ch) { - close(ch) + tryCloseCodexWebsocketRead(ch) } } e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead) @@ -1213,13 +1257,9 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, return } if ch != nil { - select { - case ch <- codexWebsocketRead{conn: conn, err: errBinary}: - case <-done: - default: - } + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: errBinary}) if sess.clearActiveForCurrentConn(conn, ch) { - close(ch) + tryCloseCodexWebsocketRead(ch) } } e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) @@ -1330,6 +1370,9 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess sessionID := sess.sessionID sess.connMu.Unlock() + // 会话显式关闭时主动唤醒活跃请求避免 readCh 悬挂 + sess.failActiveForSessionClose(conn, fmt.Errorf("codex websockets executor: execution session closed")) + if conn == nil { return } diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index 755ac56ac4..f5fc26bb6b 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -4,9 +4,12 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" + "time" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" @@ -201,3 +204,88 @@ func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) { t.Fatal("expected websocket proxy function to be nil for direct mode") } } + +func TestReadCodexWebsocketMessageReturnsWhenReadChannelClosed(t *testing.T) { + t.Parallel() + + sess := &codexWebsocketSession{} + conn := &websocket.Conn{} + readCh := make(chan codexWebsocketRead) + close(readCh) + + _, _, err := readCodexWebsocketMessage(context.Background(), sess, conn, readCh) + if err == nil { + t.Fatal("expected error when session read channel is closed") + } + if !strings.Contains(err.Error(), "session read channel closed") { + t.Fatalf("error = %v, want contains session read channel closed", err) + } +} + +func TestCloseExecutionSessionUnblocksActiveRead(t *testing.T) { + t.Parallel() + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + serverConnCh := make(chan *websocket.Conn, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + serverConnCh <- conn + _, _, _ = conn.ReadMessage() + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, errDial := websocket.DefaultDialer.Dial(wsURL, nil) + if errDial != nil { + t.Fatalf("dial websocket: %v", errDial) + } + defer func() { _ = clientConn.Close() }() + + var serverConn *websocket.Conn + select { + case serverConn = <-serverConnCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for server websocket connection") + } + + sess := &codexWebsocketSession{ + sessionID: "session-close", + conn: serverConn, + readerConn: serverConn, + } + readCh := make(chan codexWebsocketRead, 4) + sess.setActive(readCh) + + executor := &CodexWebsocketsExecutor{ + CodexExecutor: &CodexExecutor{}, + sessions: map[string]*codexWebsocketSession{ + "session-close": sess, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + readErrCh := make(chan error, 1) + go func() { + _, _, err := readCodexWebsocketMessage(ctx, sess, serverConn, readCh) + readErrCh <- err + }() + + executor.CloseExecutionSession("session-close") + + select { + case err := <-readErrCh: + if err == nil { + t.Fatal("expected read error after closing execution session") + } + errText := err.Error() + if !strings.Contains(errText, "execution session closed") && !strings.Contains(errText, "session read channel closed") { + t.Fatalf("error = %v, want fast-fail error from session close path", err) + } + case <-time.After(3 * time.Second): + t.Fatal("read did not fail fast after closeExecutionSession") + } +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index ef084eeb5a..fee888d596 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -390,6 +390,34 @@ func TestSetWebsocketRequestBody(t *testing.T) { } } +func TestShouldResetResponsesWebsocketAuthPin(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + statusCode int + want bool + }{ + {name: "too_many_requests", statusCode: http.StatusTooManyRequests, want: true}, + {name: "forbidden", statusCode: http.StatusForbidden, want: true}, + {name: "payment_required", statusCode: http.StatusPaymentRequired, want: true}, + {name: "unauthorized", statusCode: http.StatusUnauthorized, want: false}, + {name: "internal_error", statusCode: http.StatusInternalServerError, want: false}, + {name: "zero", statusCode: 0, want: false}, + } + + for i := range cases { + tc := cases[i] + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := shouldResetResponsesWebsocketAuthPin(tc.statusCode) + if got != tc.want { + t.Fatalf("shouldResetResponsesWebsocketAuthPin(%d) = %v, want %v", tc.statusCode, got, tc.want) + } + }) + } +} + func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { gin.SetMode(gin.TestMode) From 51ea8f1e9b1eb4994c8198cdaf46fb911c6d9ce2 Mon Sep 17 00:00:00 2001 From: AYANGarch Date: Sat, 21 Mar 2026 02:46:47 +0000 Subject: [PATCH 7/8] =?UTF-8?q?fix(websocket):=20=E5=A4=B1=E8=B4=A5?= =?UTF-8?q?=E8=BD=AE=E6=AC=A1=E4=B8=8D=E6=8F=90=E4=BA=A4=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E5=BF=AB=E7=85=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdk/api/handlers/openai/openai_responses_websocket.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 7f4ea9c0a3..f161e11eca 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -192,9 +192,6 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } - // lastRequest 始终保存完整 transcript 快照 - lastRequest = nextSessionRequestSnapshot - modelName := gjson.GetBytes(requestJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx) @@ -238,7 +235,12 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { // 新账号完成一轮后恢复增量模式 forceDisableIncrementalAfterAuthReset = false } - lastResponseOutput = completedOutput + if terminalStatus == 0 { + // 仅在本轮成功后提交快照避免失败轮次污染会话历史 + lastRequest = nextSessionRequestSnapshot + // 仅在本轮成功后提交输出避免失败把状态推进到空输出 + lastResponseOutput = completedOutput + } } } From 7e7cc1a4cc673cf8f452f70d90729ad91e837ae8 Mon Sep 17 00:00:00 2001 From: AYANGarch Date: Sat, 21 Mar 2026 11:07:32 +0000 Subject: [PATCH 8/8] =?UTF-8?q?fix(websocket):=20=E8=B7=9F=E8=BF=9B=20revi?= =?UTF-8?q?ew=20=E8=A1=A5=E9=BD=90=20401=20=E8=A7=A3=20pin=20=E4=B8=8E?= =?UTF-8?q?=E5=8F=AF=E8=A7=82=E6=B5=8B=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../executor/codex_websockets_executor.go | 58 +++++++-------- .../codex_websockets_executor_test.go | 73 +++++++++++++++++++ .../openai/openai_responses_websocket.go | 7 +- .../openai/openai_responses_websocket_test.go | 2 +- 4 files changed, 107 insertions(+), 33 deletions(-) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 709200ef72..cec74b3c41 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -86,7 +86,11 @@ func trySendCodexWebsocketRead(ch chan codexWebsocketRead, done <-chan struct{}, if ch == nil { return } - defer func() { _ = recover() }() + defer func() { + if r := recover(); r != nil { + log.Debugf("codex websockets executor: recover trySendCodexWebsocketRead panic=%v", r) + } + }() select { case ch <- ev: case <-done: @@ -98,7 +102,11 @@ func tryCloseCodexWebsocketRead(ch chan codexWebsocketRead) { if ch == nil { return } - defer func() { _ = recover() }() + defer func() { + if r := recover(); r != nil { + log.Debugf("codex websockets executor: recover tryCloseCodexWebsocketRead panic=%v", r) + } + }() close(ch) } @@ -107,6 +115,7 @@ func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) { return } // 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突 + // 不要在持有 connMu 时调用避免未来引入反向锁序 s.activeMu.Lock() if s.activeCancel != nil { s.activeCancel() @@ -127,6 +136,7 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { return } // 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突 + // 不要在持有 connMu 时调用避免未来引入反向锁序 s.activeMu.Lock() if s.activeCh == ch { s.activeCh = nil @@ -194,30 +204,6 @@ func (s *codexWebsocketSession) clearActiveForCurrentConn(conn *websocket.Conn, return true } -func (s *codexWebsocketSession) failActiveForSessionClose(conn *websocket.Conn, err error) { - if s == nil { - return - } - if err == nil { - err = fmt.Errorf("codex websockets executor: execution session closed") - } - s.activeMu.Lock() - ch := s.activeCh - done := s.activeDone - if s.activeCancel != nil { - s.activeCancel() - } - s.activeCh = nil - s.activeCancel = nil - s.activeDone = nil - s.activeMu.Unlock() - if ch == nil { - return - } - trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: err}) - tryCloseCodexWebsocketRead(ch) -} - func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { if s == nil { return fmt.Errorf("codex websockets executor: session is nil") @@ -1359,19 +1345,33 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess reason = "session_closed" } + // 锁顺序固定为 connMu -> activeMu sess.connMu.Lock() conn := sess.conn authID := sess.authID wsURL := sess.wsURL + sessionID := sess.sessionID sess.conn = nil if sess.readerConn == conn { sess.readerConn = nil } - sessionID := sess.sessionID + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + if sess.activeCancel != nil { + sess.activeCancel() + } + sess.activeCh = nil + sess.activeCancel = nil + sess.activeDone = nil + sess.activeMu.Unlock() sess.connMu.Unlock() - // 会话显式关闭时主动唤醒活跃请求避免 readCh 悬挂 - sess.failActiveForSessionClose(conn, fmt.Errorf("codex websockets executor: execution session closed")) + if ch != nil { + // 会话关闭时允许主动 fail active 唤醒在途 readCodexWebsocketMessage + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: fmt.Errorf("codex websockets executor: execution session closed")}) + tryCloseCodexWebsocketRead(ch) + } if conn == nil { return diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index f5fc26bb6b..4935927599 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -289,3 +289,76 @@ func TestCloseExecutionSessionUnblocksActiveRead(t *testing.T) { t.Fatal("read did not fail fast after closeExecutionSession") } } + +func TestEnsureUpstreamConnAuthSwitchRebuildsWebsocketConn(t *testing.T) { + t.Parallel() + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + authHeaderCh := make(chan string, 4) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer func() { _ = conn.Close() }() + + authHeaderCh <- strings.TrimSpace(r.Header.Get("Authorization")) + for { + _, _, errRead := conn.ReadMessage() + if errRead != nil { + return + } + } + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + executor := NewCodexWebsocketsExecutor(&config.Config{}) + sess := &codexWebsocketSession{sessionID: "session-auth-switch"} + + headers1 := http.Header{} + headers1.Set("Authorization", "Bearer token-1") + conn1, _, errDial1 := executor.ensureUpstreamConn(context.Background(), nil, sess, "auth-1", wsURL, headers1) + if errDial1 != nil { + t.Fatalf("ensureUpstreamConn auth-1 error: %v", errDial1) + } + if conn1 == nil { + t.Fatal("ensureUpstreamConn auth-1 returned nil conn") + } + + headers2 := http.Header{} + headers2.Set("Authorization", "Bearer token-2") + conn2, _, errDial2 := executor.ensureUpstreamConn(context.Background(), nil, sess, "auth-2", wsURL, headers2) + if errDial2 != nil { + t.Fatalf("ensureUpstreamConn auth-2 error: %v", errDial2) + } + if conn2 == nil { + t.Fatal("ensureUpstreamConn auth-2 returned nil conn") + } + if conn2 == conn1 { + t.Fatal("expected new websocket conn after auth switch") + } + + defer executor.invalidateUpstreamConn(sess, conn2, "test_done", nil) + + var got1, got2 string + select { + case got1 = <-authHeaderCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first websocket handshake") + } + select { + case got2 = <-authHeaderCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for second websocket handshake") + } + if got1 != "Bearer token-1" { + t.Fatalf("first Authorization = %q, want %q", got1, "Bearer token-1") + } + if got2 != "Bearer token-2" { + t.Fatalf("second Authorization = %q, want %q", got2, "Bearer token-2") + } + if got1 == got2 { + t.Fatal("expected different Authorization headers after auth switch") + } +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index f161e11eca..aaafcae760 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -165,7 +165,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { if shadowErr != nil { // 影子快照失败时保留旧快照避免污染会话状态 nextSessionRequestSnapshot = lastRequest - log.Warnf( + log.Errorf( "responses websocket: keep previous snapshot id=%s status=%d error=%v", passthroughSessionID, shadowErr.StatusCode, @@ -224,6 +224,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } if shouldResetResponsesWebsocketAuthPin(terminalStatus) { // 限额错误后解除 pin 让后续请求重新选可用账号 + log.Infof("responses websocket: reset auth pin id=%s status=%d", passthroughSessionID, terminalStatus) pinnedAuthID = "" // 切号恢复阶段先禁用增量模式避免沿用旧账号 response id forceDisableIncrementalAfterAuthReset = true @@ -232,7 +233,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { h.AuthManager.CloseExecutionSession(passthroughSessionID) } } else if forceDisableIncrementalAfterAuthReset && terminalStatus == 0 { - // 新账号完成一轮后恢复增量模式 + // 仅在成功轮次恢复增量模式避免失败轮次继续透传旧 response id forceDisableIncrementalAfterAuthReset = false } if terminalStatus == 0 { @@ -773,7 +774,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( func shouldResetResponsesWebsocketAuthPin(statusCode int) bool { switch statusCode { - case http.StatusTooManyRequests, http.StatusForbidden, http.StatusPaymentRequired: + case http.StatusTooManyRequests, http.StatusForbidden, http.StatusPaymentRequired, http.StatusUnauthorized: return true default: return false diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index fee888d596..fc55a3ddf5 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -401,7 +401,7 @@ func TestShouldResetResponsesWebsocketAuthPin(t *testing.T) { {name: "too_many_requests", statusCode: http.StatusTooManyRequests, want: true}, {name: "forbidden", statusCode: http.StatusForbidden, want: true}, {name: "payment_required", statusCode: http.StatusPaymentRequired, want: true}, - {name: "unauthorized", statusCode: http.StatusUnauthorized, want: false}, + {name: "unauthorized", statusCode: http.StatusUnauthorized, want: true}, {name: "internal_error", statusCode: http.StatusInternalServerError, want: false}, {name: "zero", statusCode: 0, want: false}, }