Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions sdk/api/handlers/openai/openai_responses_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")

completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
completedOutput, errForward, resetPinnedAuth := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if resetPinnedAuth {
if strings.TrimSpace(pinnedAuthID) != "" {
log.Infof("responses websocket: reset pinned auth id=%s auth=%s", passthroughSessionID, strings.TrimSpace(pinnedAuthID))
}
pinnedAuthID = ""
if h != nil && h.AuthManager != nil {
h.AuthManager.CloseExecutionSession(passthroughSessionID)
log.Infof("responses websocket: upstream execution session reset id=%s", passthroughSessionID)
}
}
if errForward != nil {
wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
Expand Down Expand Up @@ -598,20 +608,21 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder,
sessionID string,
) ([]byte, error) {
) ([]byte, error, bool) {
completed := false
completedOutput := []byte("[]")

for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
return completedOutput, c.Request.Context().Err(), false
case errMsg, ok := <-errs:
if !ok {
errs = nil
continue
}
resetPinnedAuth := shouldResetPinnedAuthForWebsocketError(errMsg)
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
Expand All @@ -632,15 +643,15 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, errWrite, resetPinnedAuth
}
}
if errMsg != nil {
cancel(errMsg.Error)
} else {
cancel(nil)
}
return completedOutput, nil
return completedOutput, nil, resetPinnedAuth
case chunk, ok := <-data:
if !ok {
if !completed {
Expand All @@ -667,13 +678,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, errWrite, false
}
cancel(errMsg.Error)
return completedOutput, nil
return completedOutput, nil, false
}
cancel(nil)
return completedOutput, nil
return completedOutput, nil, false
}

payloads := websocketJSONPayloadsFromChunk(chunk)
Expand All @@ -700,13 +711,36 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
return completedOutput, errWrite, false
}
}
}
}
}

func shouldResetPinnedAuthForWebsocketError(errMsg *interfaces.ErrorMessage) bool {
if errMsg == nil {
return false
}
switch errMsg.StatusCode {
case http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusNotFound, http.StatusTooManyRequests:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid resetting pinned auth on generic 404 responses

shouldResetPinnedAuthForWebsocketError currently treats any 404 as an unrecoverable auth error, but websocket upstreams can return 404 for request/state issues (for example, stale previous_response_id) that are not credential failures. In that case this branch clears pinnedAuthID and closes the execution session, so the next turn can be routed to a different auth and lose session continuity even if the client corrects the request. This makes a recoverable request error degrade into cross-auth context loss.

Useful? React with 👍 / 👎.

return true
}
if errMsg.Error == nil {
return false
}
text := strings.ToLower(strings.TrimSpace(errMsg.Error.Error()))
if text == "" {
return false
}
return strings.Contains(text, "usage_limit_reached") ||
strings.Contains(text, "insufficient_quota") ||
strings.Contains(text, "quota exceeded") ||
strings.Contains(text, "quota exhausted") ||
strings.Contains(text, "token_invalidated") ||
strings.Contains(text, "authentication token has been invalidated")
}

func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {
Expand Down
54 changes: 53 additions & 1 deletion sdk/api/handlers/openai/openai_responses_websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
close(errCh)

var bodyLog strings.Builder
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
completedOutput, err, _ := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
Expand Down Expand Up @@ -492,3 +492,55 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
t.Fatalf("unexpected forwarded input: %s", forwarded)
}
}

// TestShouldResetPinnedAuthForWebsocketError is a table-driven check of
// shouldResetPinnedAuthForWebsocketError covering a nil message, status-code
// driven resets, a text-driven reset, and an error that must not reset.
// Each case runs as its own subtest so that one failure does not hide others.
func TestShouldResetPinnedAuthForWebsocketError(t *testing.T) {
	cases := []struct {
		name   string
		errMsg *interfaces.ErrorMessage
		want   bool
	}{
		{
			name:   "nil error message",
			errMsg: nil,
			want:   false,
		},
		{
			name: "status too many requests",
			errMsg: &interfaces.ErrorMessage{
				StatusCode: http.StatusTooManyRequests,
				Error:      errors.New(`{"error":{"type":"usage_limit_reached"}}`),
			},
			want: true,
		},
		{
			name: "status forbidden",
			errMsg: &interfaces.ErrorMessage{
				StatusCode: http.StatusForbidden,
				Error:      errors.New(`{"error":{"code":"token_invalidated"}}`),
			},
			want: true,
		},
		{
			// The status code alone would not reset; the text marker does.
			name: "textual quota signal",
			errMsg: &interfaces.ErrorMessage{
				StatusCode: http.StatusBadRequest,
				Error:      errors.New("insufficient_quota"),
			},
			want: true,
		},
		{
			name: "non credential error",
			errMsg: &interfaces.ErrorMessage{
				StatusCode: http.StatusBadRequest,
				Error:      errors.New("invalid_request_error"),
			},
			want: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := shouldResetPinnedAuthForWebsocketError(tc.errMsg)
			if got != tc.want {
				// Errorf rather than Fatalf: keep going so every failing
				// case is reported in a single run.
				t.Errorf("got %v want %v", got, tc.want)
			}
		})
	}
}