diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 571a23a1eb..cec74b3c41 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -82,10 +82,40 @@ type codexWebsocketRead struct { err error } +func trySendCodexWebsocketRead(ch chan codexWebsocketRead, done <-chan struct{}, ev codexWebsocketRead) { + if ch == nil { + return + } + defer func() { + if r := recover(); r != nil { + log.Debugf("codex websockets executor: recover trySendCodexWebsocketRead panic=%v", r) + } + }() + select { + case ch <- ev: + case <-done: + default: + } +} + +func tryCloseCodexWebsocketRead(ch chan codexWebsocketRead) { + if ch == nil { + return + } + defer func() { + if r := recover(); r != nil { + log.Debugf("codex websockets executor: recover tryCloseCodexWebsocketRead panic=%v", r) + } + }() + close(ch) +} + func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) { if s == nil { return } + // 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突 + // 不要在持有 connMu 时调用避免未来引入反向锁序 s.activeMu.Lock() if s.activeCancel != nil { s.activeCancel() @@ -105,6 +135,8 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { if s == nil { return } + // 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突 + // 不要在持有 connMu 时调用避免未来引入反向锁序 s.activeMu.Lock() if s.activeCh == ch { s.activeCh = nil @@ -117,6 +149,61 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { s.activeMu.Unlock() } +func (s *codexWebsocketSession) isCurrentConn(conn *websocket.Conn) bool { + if s == nil || conn == nil { + return false + } + s.connMu.Lock() + current := s.conn + s.connMu.Unlock() + return current == conn +} + +func (s *codexWebsocketSession) activeSnapshotForCurrentConn(conn *websocket.Conn) (chan codexWebsocketRead, <-chan struct{}, bool) { + if s == nil || conn == nil { + return nil, nil, false + } + // 锁顺序固定为 connMu -> activeMu + s.connMu.Lock() + if s.conn != conn { + s.connMu.Unlock() + return nil, nil, false + } + s.activeMu.Lock() + ch := s.activeCh + done := s.activeDone + s.activeMu.Unlock() + s.connMu.Unlock() + return ch, done, true +} + +func (s *codexWebsocketSession) clearActiveForCurrentConn(conn *websocket.Conn, ch chan codexWebsocketRead) bool { + if s == nil || conn == nil || ch == nil { + return false + } + // 锁顺序固定为 connMu -> activeMu + s.connMu.Lock() + if s.conn != conn { + s.connMu.Unlock() + return false + } + s.activeMu.Lock() + if s.activeCh != ch { + s.activeMu.Unlock() + s.connMu.Unlock() + return false + } + s.activeCh = nil + if s.activeCancel != nil { + s.activeCancel() + } + s.activeCancel = nil + s.activeDone = nil + s.activeMu.Unlock() + s.connMu.Unlock() + return true +} + func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { if s == nil { return fmt.Errorf("codex websockets executor: session is nil") @@ -1064,6 +1151,7 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb } func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { + authID = strings.TrimSpace(authID) if sess == nil { return e.dialCodexWebsocket(ctx, auth, wsURL, headers) } @@ -1071,8 +1159,16 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth * sess.connMu.Lock() conn := sess.conn readerConn := sess.readerConn + currentAuthID := strings.TrimSpace(sess.authID) sess.connMu.Unlock() + if conn != nil && currentAuthID != authID { + // 账号切换时先断开旧连接避免继续复用旧账号 + e.invalidateUpstreamConn(sess, conn, "auth_switched", nil) + conn = nil + readerConn = nil + } if conn != nil { + // 账号未变化时复用连接减少不必要重连 if readerConn != conn { sess.connMu.Lock() sess.readerConn = conn @@ -1114,21 +1210,24 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, return } for { + if !sess.isCurrentConn(conn) { + // 旧连接读循环直接退出避免误伤新请求通道 + return + } _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) msgType, payload, errRead := conn.ReadMessage() if errRead != nil { - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { + // 旧连接读错时不触碰当前活跃通道 + return + } if ch != nil { - select { - case ch <- codexWebsocketRead{conn: conn, err: errRead}: - case <-done: - default: + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: errRead}) + if sess.clearActiveForCurrentConn(conn, ch) { + tryCloseCodexWebsocketRead(ch) } - sess.clearActive(ch) - close(ch) } e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead) return @@ -1137,29 +1236,29 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, if msgType != websocket.TextMessage { if msgType == websocket.BinaryMessage { errBinary := fmt.Errorf("codex websockets executor: unexpected binary message") - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { + // 旧连接二进制异常时不触碰当前活跃通道 + return + } if ch != nil { - select { - case ch <- codexWebsocketRead{conn: conn, err: errBinary}: - case <-done: - default: + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: errBinary}) + if sess.clearActiveForCurrentConn(conn, ch) { + tryCloseCodexWebsocketRead(ch) } - sess.clearActive(ch) - close(ch) } e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) return } continue } - - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { + // 旧连接消息不再分发给新连接请求 + return + } if ch == nil { continue } @@ -1246,17 +1345,34 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess reason = "session_closed" } + // 锁顺序固定为 connMu -> activeMu sess.connMu.Lock() conn := sess.conn authID := sess.authID wsURL := sess.wsURL + sessionID := sess.sessionID sess.conn = nil if sess.readerConn == conn { sess.readerConn = nil } - sessionID := sess.sessionID + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + if sess.activeCancel != nil { + sess.activeCancel() + } + sess.activeCh = nil + sess.activeCancel = nil + sess.activeDone = nil + sess.activeMu.Unlock() sess.connMu.Unlock() + if ch != nil { + // 会话关闭时允许主动 fail active 唤醒在途 readCodexWebsocketMessage + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: fmt.Errorf("codex websockets executor: execution session closed")}) + tryCloseCodexWebsocketRead(ch) + } + if conn == nil { return } diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index 755ac56ac4..4935927599 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -4,9 +4,12 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" + "time" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" @@ -201,3 +204,161 @@ func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) { t.Fatal("expected websocket proxy function to be nil for direct mode") } } + +func TestReadCodexWebsocketMessageReturnsWhenReadChannelClosed(t *testing.T) { + t.Parallel() + + sess := &codexWebsocketSession{} + conn := &websocket.Conn{} + readCh := make(chan codexWebsocketRead) + close(readCh) + + _, _, err := readCodexWebsocketMessage(context.Background(), sess, conn, readCh) + if err == nil { + t.Fatal("expected error when session read channel is closed") + } + if !strings.Contains(err.Error(), "session read channel closed") { + t.Fatalf("error = %v, want contains session read channel closed", err) + } +} + +func TestCloseExecutionSessionUnblocksActiveRead(t *testing.T) { + t.Parallel() + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + serverConnCh := make(chan *websocket.Conn, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + serverConnCh <- conn + _, _, _ = conn.ReadMessage() + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, errDial := websocket.DefaultDialer.Dial(wsURL, nil) + if errDial != nil { + t.Fatalf("dial websocket: %v", errDial) + } + defer func() { _ = clientConn.Close() }() + + var serverConn *websocket.Conn + select { + case serverConn = <-serverConnCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for server websocket connection") + } + + sess := &codexWebsocketSession{ + sessionID: "session-close", + conn: serverConn, + readerConn: serverConn, + } + readCh := make(chan codexWebsocketRead, 4) + sess.setActive(readCh) + + executor := &CodexWebsocketsExecutor{ + CodexExecutor: &CodexExecutor{}, + sessions: map[string]*codexWebsocketSession{ + "session-close": sess, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + readErrCh := make(chan error, 1) + go func() { + _, _, err := readCodexWebsocketMessage(ctx, sess, serverConn, readCh) + readErrCh <- err + }() + + executor.CloseExecutionSession("session-close") + + select { + case err := <-readErrCh: + if err == nil { + t.Fatal("expected read error after closing execution session") + } + errText := err.Error() + if !strings.Contains(errText, "execution session closed") && !strings.Contains(errText, "session read channel closed") { + t.Fatalf("error = %v, want fast-fail error from session close path", err) + } + case <-time.After(3 * time.Second): + t.Fatal("read did not fail fast after closeExecutionSession") + } +} + +func TestEnsureUpstreamConnAuthSwitchRebuildsWebsocketConn(t *testing.T) { + t.Parallel() + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + authHeaderCh := make(chan string, 4) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer func() { _ = conn.Close() }() + + authHeaderCh <- strings.TrimSpace(r.Header.Get("Authorization")) + for { + _, _, errRead := conn.ReadMessage() + if errRead != nil { + return + } + } + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + executor := NewCodexWebsocketsExecutor(&config.Config{}) + sess := &codexWebsocketSession{sessionID: "session-auth-switch"} + + headers1 := http.Header{} + headers1.Set("Authorization", "Bearer token-1") + conn1, _, errDial1 := executor.ensureUpstreamConn(context.Background(), nil, sess, "auth-1", wsURL, headers1) + if errDial1 != nil { + t.Fatalf("ensureUpstreamConn auth-1 error: %v", errDial1) + } + if conn1 == nil { + t.Fatal("ensureUpstreamConn auth-1 returned nil conn") + } + + headers2 := http.Header{} + headers2.Set("Authorization", "Bearer token-2") + conn2, _, errDial2 := executor.ensureUpstreamConn(context.Background(), nil, sess, "auth-2", wsURL, headers2) + if errDial2 != nil { + t.Fatalf("ensureUpstreamConn auth-2 error: %v", errDial2) + } + if conn2 == nil { + t.Fatal("ensureUpstreamConn auth-2 returned nil conn") + } + if conn2 == conn1 { + t.Fatal("expected new websocket conn after auth switch") + } + + defer executor.invalidateUpstreamConn(sess, conn2, "test_done", nil) + + var got1, got2 string + select { + case got1 = <-authHeaderCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first websocket handshake") + } + select { + case got2 = <-authHeaderCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for second websocket handshake") + } + if got1 != "Bearer token-1" { + t.Fatalf("first Authorization = %q, want %q", got1, "Bearer token-1") + } + if got2 != "Bearer token-2" { + t.Fatalf("second Authorization = %q, want %q", got2, "Bearer token-2") + } + if got1 == got2 { + t.Fatal("expected different Authorization headers after auth switch") + } +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 5c68f40e15..aaafcae760 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -81,6 +81,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { var lastRequest []byte lastResponseOutput := []byte("[]") pinnedAuthID := "" + forceDisableIncrementalAfterAuthReset := false for { msgType, payload, errReadMessage := conn.ReadMessage() @@ -107,16 +108,18 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { appendWebsocketEvent(&wsBodyLog, "request", payload) allowIncrementalInputWithPreviousResponseID := false - if pinnedAuthID != "" && h != nil && h.AuthManager != nil { - if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { - allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) - } - } else { - requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) - if requestModelName == "" { - requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + if !forceDisableIncrementalAfterAuthReset { + if pinnedAuthID != "" && h != nil && h.AuthManager != nil { + if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { + allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) + } + } else { + requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if requestModelName == "" { + requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + } + allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } - allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } var requestJSON []byte @@ -151,6 +154,28 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } + nextSessionRequestSnapshot := updatedLastRequest + if shouldBuildResponsesWebsocketFullSnapshot(requestJSON, allowIncrementalInputWithPreviousResponseID) { + _, shadowLastRequest, shadowErr := normalizeResponsesWebsocketRequestWithMode( + payload, + lastRequest, + lastResponseOutput, + false, + ) + if shadowErr != nil { + // 影子快照失败时保留旧快照避免污染会话状态 + nextSessionRequestSnapshot = lastRequest + log.Errorf( + "responses websocket: keep previous snapshot id=%s status=%d error=%v", + passthroughSessionID, + shadowErr.StatusCode, + shadowErr.Error, + ) + } else { + // 增量模式只发 delta 同时维护完整快照供切号恢复 + nextSessionRequestSnapshot = shadowLastRequest + } + } if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil { requestJSON = updated @@ -167,8 +192,6 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } - lastRequest = updatedLastRequest - modelName := gjson.GetBytes(requestJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx) @@ -192,14 +215,33 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") - completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) + completedOutput, terminalStatus, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) if errForward != nil { wsTerminateErr = errForward appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) return } - lastResponseOutput = completedOutput + if shouldResetResponsesWebsocketAuthPin(terminalStatus) { + // 限额错误后解除 pin 让后续请求重新选可用账号 + log.Infof("responses websocket: reset auth pin id=%s status=%d", passthroughSessionID, terminalStatus) + pinnedAuthID = "" + // 切号恢复阶段先禁用增量模式避免沿用旧账号 response id + forceDisableIncrementalAfterAuthReset = true + if h != nil && h.AuthManager != nil { + // 主动关闭旧上游会话避免继续复用旧账号连接 + h.AuthManager.CloseExecutionSession(passthroughSessionID) + } + } else if forceDisableIncrementalAfterAuthReset && terminalStatus == 0 { + // 仅在成功轮次恢复增量模式避免失败轮次继续透传旧 response id + forceDisableIncrementalAfterAuthReset = false + } + if terminalStatus == 0 { + // 仅在本轮成功后提交快照避免失败轮次污染会话历史 + lastRequest = nextSessionRequestSnapshot + // 仅在本轮成功后提交输出避免失败把状态推进到空输出 + lastResponseOutput = completedOutput + } } } @@ -488,6 +530,14 @@ func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest [] return generateResult.Exists() && !generateResult.Bool() } +func shouldBuildResponsesWebsocketFullSnapshot(normalizedRequestJSON []byte, allowIncrementalInputWithPreviousResponseID bool) bool { + if !allowIncrementalInputWithPreviousResponseID { + return false + } + prev := strings.TrimSpace(gjson.GetBytes(normalizedRequestJSON, "previous_response_id").String()) + return prev != "" +} + func writeResponsesWebsocketSyntheticPrewarm( c *gin.Context, conn *websocket.Conn, @@ -610,21 +660,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errs <-chan *interfaces.ErrorMessage, wsBodyLog *strings.Builder, sessionID string, -) ([]byte, error) { +) ([]byte, int, error) { completed := false completedOutput := []byte("[]") + terminalStatusCode := 0 for { select { case <-c.Request.Context().Done(): cancel(c.Request.Context().Err()) - return completedOutput, c.Request.Context().Err() + return completedOutput, terminalStatusCode, c.Request.Context().Err() case errMsg, ok := <-errs: if !ok { errs = nil continue } if errMsg != nil { + terminalStatusCode = errMsg.StatusCode h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) @@ -644,7 +696,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( // errWrite, // ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } } if errMsg != nil { @@ -652,7 +704,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( } else { cancel(nil) } - return completedOutput, nil + return completedOutput, terminalStatusCode, nil case chunk, ok := <-data: if !ok { if !completed { @@ -660,6 +712,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( StatusCode: http.StatusRequestTimeout, Error: fmt.Errorf("stream closed before response.completed"), } + terminalStatusCode = errMsg.StatusCode h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) @@ -679,13 +732,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } cancel(errMsg.Error) - return completedOutput, nil + return completedOutput, terminalStatusCode, nil } cancel(nil) - return completedOutput, nil + return completedOutput, terminalStatusCode, nil } payloads := websocketJSONPayloadsFromChunk(chunk) @@ -712,13 +765,22 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errWrite) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } } } } } +func shouldResetResponsesWebsocketAuthPin(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, http.StatusForbidden, http.StatusPaymentRequired, http.StatusUnauthorized: + return true + default: + return false + } +} + func responseCompletedOutputFromPayload(payload []byte) []byte { output := gjson.GetBytes(payload, "response.output") if output.Exists() && output.IsArray() { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index b3a32c5c9d..fc55a3ddf5 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -390,6 +390,34 @@ func TestSetWebsocketRequestBody(t *testing.T) { } } +func TestShouldResetResponsesWebsocketAuthPin(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + statusCode int + want bool + }{ + {name: "too_many_requests", statusCode: http.StatusTooManyRequests, want: true}, + {name: "forbidden", statusCode: http.StatusForbidden, want: true}, + {name: "payment_required", statusCode: http.StatusPaymentRequired, want: true}, + {name: "unauthorized", statusCode: http.StatusUnauthorized, want: true}, + {name: "internal_error", statusCode: http.StatusInternalServerError, want: false}, + {name: "zero", statusCode: 0, want: false}, + } + + for i := range cases { + tc := cases[i] + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := shouldResetResponsesWebsocketAuthPin(tc.statusCode) + if got != tc.want { + t.Fatalf("shouldResetResponsesWebsocketAuthPin(%d) = %v, want %v", tc.statusCode, got, tc.want) + } + }) + } +} + func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { gin.SetMode(gin.TestMode) @@ -417,7 +445,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { close(errCh) var bodyLog strings.Builder - completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + completedOutput, statusCode, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( ctx, conn, func(...interface{}) {}, @@ -430,6 +458,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { serverErrCh <- err return } + if statusCode != 0 { + serverErrCh <- fmt.Errorf("status code = %d, want 0", statusCode) + return + } if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" { serverErrCh <- errors.New("completed output not captured") return