diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 6a444b45fa..5ba8fdf0ac 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -79,6 +79,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { var lastRequest []byte lastResponseOutput := []byte("[]") pinnedAuthID := "" + resetPinnedAuthDisablesIncrementalInput := false for { msgType, payload, errReadMessage := conn.ReadMessage() @@ -104,6 +105,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { // ) appendWebsocketEvent(&wsBodyLog, "request", payload) + disableIncrementalForThisTurn := resetPinnedAuthDisablesIncrementalInput allowIncrementalInputWithPreviousResponseID := false if pinnedAuthID != "" && h != nil && h.AuthManager != nil { if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { @@ -116,6 +118,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } + if disableIncrementalForThisTurn { + allowIncrementalInputWithPreviousResponseID = false + } var requestJSON []byte var updatedLastRequest []byte @@ -180,7 +185,18 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") - completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) + completedOutput, errForward, resetPinnedAuth := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) + if resetPinnedAuth { + if strings.TrimSpace(pinnedAuthID) != "" { + log.Infof("responses websocket: reset pinned auth id=%s auth=%s", passthroughSessionID, strings.TrimSpace(pinnedAuthID)) + } + pinnedAuthID = "" + resetPinnedAuthDisablesIncrementalInput = true + if h != nil && h.AuthManager != nil { + h.AuthManager.CloseExecutionSession(passthroughSessionID) + log.Infof("responses websocket: upstream execution session reset id=%s", passthroughSessionID) + } + } if errForward != nil { wsTerminateErr = errForward appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) @@ -188,6 +204,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { return } lastResponseOutput = completedOutput + if disableIncrementalForThisTurn && !resetPinnedAuth { + resetPinnedAuthDisablesIncrementalInput = false + } } } @@ -598,7 +617,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errs <-chan *interfaces.ErrorMessage, wsBodyLog *strings.Builder, sessionID string, -) ([]byte, error) { +) ([]byte, error, bool) { completed := false completedOutput := []byte("[]") @@ -606,12 +625,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( select { case <-c.Request.Context().Done(): cancel(c.Request.Context().Err()) - return completedOutput, c.Request.Context().Err() + return completedOutput, c.Request.Context().Err(), false case errMsg, ok := <-errs: if !ok { errs = nil continue } + resetPinnedAuth := shouldResetPinnedAuthForWebsocketError(errMsg) if errMsg != nil { h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) @@ -632,7 +652,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( // errWrite, // ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, errWrite, resetPinnedAuth } } if errMsg != nil { @@ -640,7 +660,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( } else { cancel(nil) } - return completedOutput, nil + return completedOutput, nil, resetPinnedAuth case chunk, ok := <-data: if !ok { if !completed { @@ -667,13 +687,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, errWrite, false } cancel(errMsg.Error) - return completedOutput, nil + return completedOutput, nil, false } cancel(nil) - return completedOutput, nil + return completedOutput, nil, false } payloads := websocketJSONPayloadsFromChunk(chunk) @@ -700,13 +720,36 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errWrite) - return completedOutput, errWrite + return completedOutput, errWrite, false } } } } } +func shouldResetPinnedAuthForWebsocketError(errMsg *interfaces.ErrorMessage) bool { + if errMsg == nil { + return false + } + switch errMsg.StatusCode { + case http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusNotFound, http.StatusTooManyRequests: + return true + } + if errMsg.Error == nil { + return false + } + text := strings.ToLower(strings.TrimSpace(errMsg.Error.Error())) + if text == "" { + return false + } + return strings.Contains(text, "usage_limit_reached") || + strings.Contains(text, "insufficient_quota") || + strings.Contains(text, "quota exceeded") || + strings.Contains(text, "quota exhausted") || + strings.Contains(text, "token_invalidated") || + strings.Contains(text, "authentication token has been invalidated") +} + func responseCompletedOutputFromPayload(payload []byte) []byte { output := gjson.GetBytes(payload, "response.output") if output.Exists() && output.IsArray() { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index d30c648d9e..19921ea56d 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/gin-gonic/gin" @@ -26,6 +27,76 @@ type websocketCaptureExecutor struct { payloads [][]byte } +type websocketStatusErr struct { + code int + msg string +} + +func (e websocketStatusErr) Error() string { return e.msg } +func (e websocketStatusErr) StatusCode() int { return e.code } + +type websocketSequenceStep struct { + wantAuthID string + responseID string + outputID string + err error +} + +type websocketSequenceExecutor struct { + id string + t *testing.T + + mu sync.Mutex + streamCalls int + payloads [][]byte + authIDs []string + steps []websocketSequenceStep +} + +func (e *websocketSequenceExecutor) Identifier() string { return e.id } + +func (e *websocketSequenceExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketSequenceExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + e.mu.Lock() + defer e.mu.Unlock() + idx := e.streamCalls + if idx >= len(e.steps) { + e.t.Fatalf("unexpected ExecuteStream call #%d for auth %s", idx+1, auth.ID) + } + step := e.steps[idx] + e.streamCalls++ + e.authIDs = append(e.authIDs, auth.ID) + e.payloads = append(e.payloads, bytes.Clone(req.Payload)) + if step.wantAuthID != "" && auth.ID != step.wantAuthID { + e.t.Fatalf("ExecuteStream call #%d auth = %s, want %s", idx+1, auth.ID, step.wantAuthID) + } + chunks := make(chan coreexecutor.StreamChunk, 1) + if step.err != nil { + chunks <- coreexecutor.StreamChunk{Err: step.err} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil + } + payload := fmt.Sprintf(`{"type":"response.completed","response":{"id":%q,"output":[{"type":"message","id":%q}]}}`, step.responseID, step.outputID) + chunks <- coreexecutor.StreamChunk{Payload: []byte(payload)} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketSequenceExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketSequenceExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketSequenceExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -317,7 +388,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { close(errCh) var bodyLog strings.Builder - completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + completedOutput, err, _ := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( ctx, conn, func(...interface{}) {}, @@ -492,3 +563,157 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) { t.Fatalf("unexpected forwarded input: %s", forwarded) } } + +func TestResponsesWebsocket_DropsPreviousResponseIDAfterPinnedAuthReset(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &websocketSequenceExecutor{ + id: "test-provider", + t: t, + steps: []websocketSequenceStep{ + {wantAuthID: "auth-a", responseID: "resp-a-1", outputID: "out-a-1"}, + {wantAuthID: "auth-a", err: websocketStatusErr{code: http.StatusForbidden, msg: `{"error":{"code":"token_invalidated"}}`}}, + {wantAuthID: "auth-b", responseID: "resp-b-1", outputID: "out-b-1"}, + }, + } + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + authA := &coreauth.Auth{ID: "auth-a", Provider: executor.Identifier(), Status: coreauth.StatusActive, Attributes: map[string]string{"websockets": "true"}} + authB := &coreauth.Auth{ID: "auth-b", Provider: executor.Identifier(), Status: coreauth.StatusActive, Attributes: map[string]string{"websockets": "true"}} + if _, err := manager.Register(context.Background(), authA); err != nil { + t.Fatalf("Register authA: %v", err) + } + if _, err := manager.Register(context.Background(), authB); err != nil { + t.Fatalf("Register authB: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(authA.ID, authA.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(authB.ID, authB.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(authA.ID) + registry.GetGlobalRegistry().UnregisterClient(authB.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + if errClose := conn.Close(); errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`)); errWrite != nil { + t.Fatalf("write first websocket message: %v", errWrite) + } + _, payload, errRead := conn.ReadMessage() + if errRead != nil { + t.Fatalf("read first websocket response: %v", errRead) + } + if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted { + t.Fatalf("first payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted) + } + previousResponseID := gjson.GetBytes(payload, "response.id").String() + if previousResponseID != "resp-a-1" { + t.Fatalf("first response id = %s, want resp-a-1", previousResponseID) + } + + secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-2"}]}`, previousResponseID) + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(secondRequest)); errWrite != nil { + t.Fatalf("write second websocket message: %v", errWrite) + } + _, payload, errRead = conn.ReadMessage() + if errRead != nil { + t.Fatalf("read second websocket response: %v", errRead) + } + if !gjson.GetBytes(payload, "error").Exists() { + t.Fatalf("expected websocket error payload after pinned auth failure, got %s", payload) + } + + thirdRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-3"}]}`, previousResponseID) + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(thirdRequest)); errWrite != nil { + t.Fatalf("write third websocket message: %v", errWrite) + } + _, payload, errRead = conn.ReadMessage() + if errRead != nil { + t.Fatalf("read third websocket response: %v", errRead) + } + if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted { + t.Fatalf("third payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted) + } + + if executor.streamCalls != 3 { + t.Fatalf("stream calls = %d, want 3", executor.streamCalls) + } + if len(executor.authIDs) != 3 || executor.authIDs[0] != "auth-a" || executor.authIDs[1] != "auth-a" || executor.authIDs[2] != "auth-b" { + t.Fatalf("unexpected auth selection sequence: %#v", executor.authIDs) + } + if !gjson.GetBytes(executor.payloads[1], "previous_response_id").Exists() { + t.Fatalf("second upstream request should preserve previous_response_id while pin is valid: %s", executor.payloads[1]) + } + if gjson.GetBytes(executor.payloads[2], "previous_response_id").Exists() { + t.Fatalf("third upstream request leaked previous_response_id after pin reset: %s", executor.payloads[2]) + } +} + +func TestShouldResetPinnedAuthForWebsocketError(t *testing.T) { + cases := []struct { + name string + errMsg *interfaces.ErrorMessage + want bool + }{ + { + name: "nil error message", + errMsg: nil, + want: false, + }, + { + name: "status too many requests", + errMsg: &interfaces.ErrorMessage{ + StatusCode: http.StatusTooManyRequests, + Error: errors.New(`{"error":{"type":"usage_limit_reached"}}`), + }, + want: true, + }, + { + name: "status forbidden", + errMsg: &interfaces.ErrorMessage{ + StatusCode: http.StatusForbidden, + Error: errors.New(`{"error":{"code":"token_invalidated"}}`), + }, + want: true, + }, + { + name: "textual quota signal", + errMsg: &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: errors.New("insufficient_quota"), + }, + want: true, + }, + { + name: "non credential error", + errMsg: &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: errors.New("invalid_request_error"), + }, + want: false, + }, + } + + for _, tc := range cases { + if got := shouldResetPinnedAuthForWebsocketError(tc.errMsg); got != tc.want { + t.Fatalf("%s: got %v want %v", tc.name, got, tc.want) + } + } +}