Skip to content
35 changes: 35 additions & 0 deletions internal/runtime/executor/codex_websockets_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) {
s.activeMu.Unlock()
}

func (s *codexWebsocketSession) isCurrentConn(conn *websocket.Conn) bool {
if s == nil || conn == nil {
return false
}
s.connMu.Lock()
current := s.conn
s.connMu.Unlock()
return current == conn
}

func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error {
if s == nil {
return fmt.Errorf("codex websockets executor: session is nil")
Expand Down Expand Up @@ -1067,12 +1077,21 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
if sess == nil {
return e.dialCodexWebsocket(ctx, auth, wsURL, headers)
}
authID = strings.TrimSpace(authID)

sess.connMu.Lock()
conn := sess.conn
readerConn := sess.readerConn
currentAuthID := strings.TrimSpace(sess.authID)
sess.connMu.Unlock()
if conn != nil && currentAuthID != authID {
// 账号切换时先断开旧连接避免继续复用旧账号
e.invalidateUpstreamConn(sess, conn, "auth_switched", nil)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid tearing down active reader channel on auth switch

Calling invalidateUpstreamConn during auth changes closes the old socket immediately, but the old readUpstreamLoop can still wake later and run its error path, which unconditionally clearActive/closes whatever channel is currently active for the session. If a new request has already installed its own activeCh, that channel gets closed and readCodexWebsocketMessage returns session read channel closed, aborting a healthy post-switch request. This race is specific to in-session auth switching and can break the new quota-recovery flow intermittently.

Useful? React with 👍 / 👎.

conn = nil
readerConn = nil
}
if conn != nil {
// 账号未变化时复用连接减少不必要重连
if readerConn != conn {
sess.connMu.Lock()
sess.readerConn = conn
Expand Down Expand Up @@ -1114,9 +1133,17 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession,
return
}
for {
if !sess.isCurrentConn(conn) {
// 旧连接读循环直接退出避免误伤新请求通道
return
}
_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
msgType, payload, errRead := conn.ReadMessage()
if errRead != nil {
if !sess.isCurrentConn(conn) {
// 旧连接读错时不触碰当前活跃通道
return
}
Comment on lines +1222 to +1225

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Notify active reader when old conn is force-closed

Do not return immediately on !current here, because closeExecutionSession sets sess.conn = nil before closing the socket; when that close unblocks ReadMessage, this branch exits without sending an error to the active channel. In that case readCodexWebsocketMessage keeps waiting on readCh for the in-flight request until its context is canceled, so explicit session shutdown (for example CloseExecutionSession/executor replacement paths) can hang requests instead of failing fast.

Useful? React with 👍 / 👎.

sess.activeMu.Lock()
ch := sess.activeCh
done := sess.activeDone
Expand All @@ -1137,6 +1164,10 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession,
if msgType != websocket.TextMessage {
if msgType == websocket.BinaryMessage {
errBinary := fmt.Errorf("codex websockets executor: unexpected binary message")
if !sess.isCurrentConn(conn) {
// 旧连接二进制异常时不触碰当前活跃通道
return
}
sess.activeMu.Lock()
ch := sess.activeCh
done := sess.activeDone
Expand All @@ -1155,6 +1186,10 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession,
}
continue
}
if !sess.isCurrentConn(conn) {
// 旧连接消息不再分发给新连接请求
return
}

sess.activeMu.Lock()
ch := sess.activeCh
Expand Down
64 changes: 46 additions & 18 deletions sdk/api/handlers/openai/openai_responses_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
var lastRequest []byte
lastResponseOutput := []byte("[]")
pinnedAuthID := ""
forceDisableIncrementalAfterAuthReset := false

for {
msgType, payload, errReadMessage := conn.ReadMessage()
Expand All @@ -107,16 +108,18 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
appendWebsocketEvent(&wsBodyLog, "request", payload)

allowIncrementalInputWithPreviousResponseID := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
if !forceDisableIncrementalAfterAuthReset {
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}

var requestJSON []byte
Expand Down Expand Up @@ -192,13 +195,26 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")

completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
completedOutput, terminalStatus, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
if shouldResetResponsesWebsocketAuthPin(terminalStatus) {
// 限额错误后解除 pin 让后续请求重新选可用账号
pinnedAuthID = ""
// 切号恢复阶段先禁用增量模式避免沿用旧账号 response id
forceDisableIncrementalAfterAuthReset = true

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve transcript when forcing non-incremental recovery

Setting forceDisableIncrementalAfterAuthReset = true after quota-class errors makes the next turn fall back to normalizeResponseSubsequentRequest's non-incremental merge path, but the in-memory state is still incremental-form state (lastRequest was already replaced by the previous incremental payload and failed turns leave lastResponseOutput as []). In websocket v2 sessions that send only delta input with previous_response_id, this causes the post-switch request to be built from a truncated history (often just the last user delta), so recovery succeeds but silently drops earlier conversation context.

Useful? React with 👍 / 👎.

if h != nil && h.AuthManager != nil {
Comment on lines +225 to +231

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Reset incremental state when unpinning after quota errors

After a quota-class terminal error this block clears pinnedAuthID and closes the execution session, but it leaves request state unchanged; on the next turn, allowIncrementalInputWithPreviousResponseID can still be true (model-level), so normalizeResponseSubsequentRequest may forward a client previous_response_id that belongs to the old account to the newly selected account. In websocket v2 flows that send previous_response_id, this causes the post-switch request to fail with upstream 4xx instead of recovering, so the auth-rotation fix is incomplete unless incremental mode is disabled/reset for at least the next turn.

Useful? React with 👍 / 👎.

// 主动关闭旧上游会话避免继续复用旧账号连接
h.AuthManager.CloseExecutionSession(passthroughSessionID)
}
} else if forceDisableIncrementalAfterAuthReset && terminalStatus == 0 {
// 新账号完成一轮后恢复增量模式
forceDisableIncrementalAfterAuthReset = false
}
lastResponseOutput = completedOutput
}
}
Expand Down Expand Up @@ -610,21 +626,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder,
sessionID string,
) ([]byte, error) {
) ([]byte, int, error) {
completed := false
completedOutput := []byte("[]")
terminalStatusCode := 0

for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
return completedOutput, terminalStatusCode, c.Request.Context().Err()
case errMsg, ok := <-errs:
if !ok {
errs = nil
continue
}
if errMsg != nil {
terminalStatusCode = errMsg.StatusCode
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
Expand All @@ -644,22 +662,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, terminalStatusCode, errWrite
}
}
if errMsg != nil {
cancel(errMsg.Error)
} else {
cancel(nil)
}
return completedOutput, nil
return completedOutput, terminalStatusCode, nil
case chunk, ok := <-data:
if !ok {
if !completed {
errMsg := &interfaces.ErrorMessage{
StatusCode: http.StatusRequestTimeout,
Error: fmt.Errorf("stream closed before response.completed"),
}
terminalStatusCode = errMsg.StatusCode
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
Expand All @@ -679,13 +698,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, terminalStatusCode, errWrite
}
cancel(errMsg.Error)
return completedOutput, nil
return completedOutput, terminalStatusCode, nil
}
cancel(nil)
return completedOutput, nil
return completedOutput, terminalStatusCode, nil
}

payloads := websocketJSONPayloadsFromChunk(chunk)
Expand All @@ -712,13 +731,22 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
return completedOutput, terminalStatusCode, errWrite
}
}
}
}
}

func shouldResetResponsesWebsocketAuthPin(statusCode int) bool {
switch statusCode {
case http.StatusTooManyRequests, http.StatusForbidden, http.StatusPaymentRequired:
return true
default:
return false
}
}

func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {
Expand Down
6 changes: 5 additions & 1 deletion sdk/api/handlers/openai/openai_responses_websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
close(errCh)

var bodyLog strings.Builder
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
completedOutput, statusCode, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
Expand All @@ -430,6 +430,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
serverErrCh <- err
return
}
if statusCode != 0 {
serverErrCh <- fmt.Errorf("status code = %d, want 0", statusCode)
return
}
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
serverErrCh <- errors.New("completed output not captured")
return
Expand Down
Loading