Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions sdk/api/handlers/openai/openai_responses_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
var lastRequest []byte
lastResponseOutput := []byte("[]")
pinnedAuthID := ""
resetPinnedAuthDisablesIncrementalInput := false

for {
msgType, payload, errReadMessage := conn.ReadMessage()
Expand All @@ -104,6 +105,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
// )
appendWebsocketEvent(&wsBodyLog, "request", payload)

disableIncrementalForThisTurn := resetPinnedAuthDisablesIncrementalInput
allowIncrementalInputWithPreviousResponseID := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
Expand All @@ -116,6 +118,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
if disableIncrementalForThisTurn {
allowIncrementalInputWithPreviousResponseID = false
}

var requestJSON []byte
var updatedLastRequest []byte
Expand Down Expand Up @@ -180,14 +185,28 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")

completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
completedOutput, errForward, resetPinnedAuth := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if resetPinnedAuth {
if strings.TrimSpace(pinnedAuthID) != "" {
log.Infof("responses websocket: reset pinned auth id=%s auth=%s", passthroughSessionID, strings.TrimSpace(pinnedAuthID))
}
pinnedAuthID = ""
resetPinnedAuthDisablesIncrementalInput = true
if h != nil && h.AuthManager != nil {
h.AuthManager.CloseExecutionSession(passthroughSessionID)
log.Infof("responses websocket: upstream execution session reset id=%s", passthroughSessionID)
}
}
if errForward != nil {
wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
lastResponseOutput = completedOutput
if disableIncrementalForThisTurn && !resetPinnedAuth {
resetPinnedAuthDisablesIncrementalInput = false
}
}
}

Expand Down Expand Up @@ -598,20 +617,21 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder,
sessionID string,
) ([]byte, error) {
) ([]byte, error, bool) {
completed := false
completedOutput := []byte("[]")

for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
return completedOutput, c.Request.Context().Err(), false
case errMsg, ok := <-errs:
if !ok {
errs = nil
continue
}
resetPinnedAuth := shouldResetPinnedAuthForWebsocketError(errMsg)
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
Expand All @@ -632,15 +652,15 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, errWrite, resetPinnedAuth
}
}
if errMsg != nil {
cancel(errMsg.Error)
} else {
cancel(nil)
}
return completedOutput, nil
return completedOutput, nil, resetPinnedAuth
case chunk, ok := <-data:
if !ok {
if !completed {
Expand All @@ -667,13 +687,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, errWrite, false
}
cancel(errMsg.Error)
return completedOutput, nil
return completedOutput, nil, false
}
cancel(nil)
return completedOutput, nil
return completedOutput, nil, false
}

payloads := websocketJSONPayloadsFromChunk(chunk)
Expand All @@ -700,13 +720,36 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
return completedOutput, errWrite, false
}
}
}
}
}

func shouldResetPinnedAuthForWebsocketError(errMsg *interfaces.ErrorMessage) bool {
if errMsg == nil {
return false
}
switch errMsg.StatusCode {
case http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusNotFound, http.StatusTooManyRequests:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid resetting pinned auth on generic 404 responses

shouldResetPinnedAuthForWebsocketError currently treats any 404 as an unrecoverable auth error, but websocket upstreams can return 404 for request/state issues (for example, stale previous_response_id) that are not credential failures. In that case this branch clears pinnedAuthID and closes the execution session, so the next turn can be routed to a different auth and lose session continuity even if the client corrects the request. This makes a recoverable request error degrade into cross-auth context loss.

Useful? React with 👍 / 👎.

return true
}
if errMsg.Error == nil {
return false
}
text := strings.ToLower(strings.TrimSpace(errMsg.Error.Error()))
if text == "" {
return false
}
return strings.Contains(text, "usage_limit_reached") ||
strings.Contains(text, "insufficient_quota") ||
strings.Contains(text, "quota exceeded") ||
strings.Contains(text, "quota exhausted") ||
strings.Contains(text, "token_invalidated") ||
strings.Contains(text, "authentication token has been invalidated")
}

func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {
Expand Down
Loading