Skip to content
166 changes: 141 additions & 25 deletions internal/runtime/executor/codex_websockets_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,31 @@ type codexWebsocketRead struct {
err error
}

// trySendCodexWebsocketRead makes a single non-blocking attempt to deliver ev
// on ch.
//
// Nothing happens when ch is nil. The event is dropped when ch cannot accept
// it immediately, and select may also choose the done case instead of the
// send when done is ready. A panic raised during the attempt (for example a
// send on a channel that was already closed) is recovered and discarded.
func trySendCodexWebsocketRead(ch chan codexWebsocketRead, done <-chan struct{}, ev codexWebsocketRead) {
	if ch != nil {
		defer func() {
			// Best effort: swallow the panic from sending on a closed channel.
			_ = recover()
		}()
		select {
		case <-done:
		case ch <- ev:
		default:
		}
	}
}

// tryCloseCodexWebsocketRead closes ch if it is non-nil.
//
// The panic produced by closing an already-closed channel is recovered and
// discarded, so calling this more than once for the same channel is harmless.
func tryCloseCodexWebsocketRead(ch chan codexWebsocketRead) {
	if ch != nil {
		defer func() {
			// Best effort: a double close must not crash the caller.
			_ = recover()
		}()
		close(ch)
	}
}

func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) {
if s == nil {
return
}
// 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突
s.activeMu.Lock()
if s.activeCancel != nil {
s.activeCancel()
Expand All @@ -105,6 +126,7 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) {
if s == nil {
return
}
// 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突
s.activeMu.Lock()
if s.activeCh == ch {
s.activeCh = nil
Expand All @@ -117,6 +139,85 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) {
s.activeMu.Unlock()
}

// isCurrentConn reports whether conn is the connection currently stored on
// the session. It returns false for a nil session or a nil conn. The stored
// connection is read under connMu.
func (s *codexWebsocketSession) isCurrentConn(conn *websocket.Conn) bool {
	if conn == nil || s == nil {
		return false
	}
	s.connMu.Lock()
	defer s.connMu.Unlock()
	return s.conn == conn
}

// activeSnapshotForCurrentConn returns the session's active read channel and
// its done channel, but only while conn is still the session's current
// connection.
//
// The third result is false (with nil channels) when the session or conn is
// nil, or when conn is no longer the stored connection. The ownership check
// and the snapshot happen inside the same connMu critical section.
func (s *codexWebsocketSession) activeSnapshotForCurrentConn(conn *websocket.Conn) (chan codexWebsocketRead, <-chan struct{}, bool) {
	if conn == nil || s == nil {
		return nil, nil, false
	}
	// Lock order is always connMu first, then activeMu.
	s.connMu.Lock()
	defer s.connMu.Unlock()
	if s.conn != conn {
		return nil, nil, false
	}
	s.activeMu.Lock()
	defer s.activeMu.Unlock()
	return s.activeCh, s.activeDone, true
}

// clearActiveForCurrentConn detaches ch as the session's active read channel,
// provided conn is still the session's current connection and ch is still
// the active channel.
//
// On success the active cancel func (if any) is invoked, the active channel,
// cancel func and done channel are reset to nil, and true is returned. It
// returns false without touching any state when the session, conn or ch is
// nil, when conn is not the stored connection, or when ch is not the active
// channel.
func (s *codexWebsocketSession) clearActiveForCurrentConn(conn *websocket.Conn, ch chan codexWebsocketRead) bool {
	if s == nil || conn == nil || ch == nil {
		return false
	}
	// Lock order is always connMu first, then activeMu.
	s.connMu.Lock()
	defer s.connMu.Unlock()
	if s.conn != conn {
		return false
	}
	s.activeMu.Lock()
	defer s.activeMu.Unlock()
	if s.activeCh != ch {
		return false
	}
	cancel := s.activeCancel
	s.activeCh, s.activeCancel, s.activeDone = nil, nil, nil
	if cancel != nil {
		cancel()
	}
	return true
}

// failActiveForSessionClose detaches the session's active read channel and
// wakes whoever is waiting on it because the session is being closed.
//
// conn is attached to the delivered event. err is the error reported to the
// waiter; when nil it defaults to an "execution session closed" error.
//
// The active channel, its done channel and its cancel func are snapshotted
// and reset to nil while holding activeMu only (connMu is not taken here).
// After the lock is released the error is offered to the channel with a
// non-blocking send, the cancel func is invoked, and the channel is closed.
// A nil session is a no-op.
func (s *codexWebsocketSession) failActiveForSessionClose(conn *websocket.Conn, err error) {
	if s == nil {
		return
	}
	if err == nil {
		err = fmt.Errorf("codex websockets executor: execution session closed")
	}

	s.activeMu.Lock()
	ch := s.activeCh
	done := s.activeDone
	cancel := s.activeCancel
	s.activeCh = nil
	s.activeCancel = nil
	s.activeDone = nil
	s.activeMu.Unlock()

	// Offer the error BEFORE cancelling, matching the send-then-clear order
	// used on the read-loop error paths.
	// NOTE(review): activeDone is presumably the Done channel of the context
	// that activeCancel cancels — confirm against setActive. If so,
	// cancelling first makes the done case inside trySendCodexWebsocketRead
	// ready, and select may pick it instead of the send, dropping the error
	// so the waiter only observes a closed channel.
	if ch != nil {
		trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: err})
	}
	if cancel != nil {
		cancel()
	}
	if ch != nil {
		tryCloseCodexWebsocketRead(ch)
	}
}

func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error {
if s == nil {
return fmt.Errorf("codex websockets executor: session is nil")
Expand Down Expand Up @@ -1064,15 +1165,24 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
}

func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
authID = strings.TrimSpace(authID)
if sess == nil {
return e.dialCodexWebsocket(ctx, auth, wsURL, headers)
}

sess.connMu.Lock()
conn := sess.conn
readerConn := sess.readerConn
currentAuthID := strings.TrimSpace(sess.authID)
sess.connMu.Unlock()
if conn != nil && currentAuthID != authID {
// 账号切换时先断开旧连接避免继续复用旧账号
e.invalidateUpstreamConn(sess, conn, "auth_switched", nil)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid tearing down active reader channel on auth switch

Calling invalidateUpstreamConn during auth changes closes the old socket immediately, but the old readUpstreamLoop can still wake later and run its error path, which unconditionally clearActive/closes whatever channel is currently active for the session. If a new request has already installed its own activeCh, that channel gets closed and readCodexWebsocketMessage returns session read channel closed, aborting a healthy post-switch request. This race is specific to in-session auth switching and can break the new quota-recovery flow intermittently.

Useful? React with 👍 / 👎.

conn = nil
readerConn = nil
}
if conn != nil {
// 账号未变化时复用连接减少不必要重连
if readerConn != conn {
sess.connMu.Lock()
sess.readerConn = conn
Expand Down Expand Up @@ -1114,21 +1224,24 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession,
return
}
for {
if !sess.isCurrentConn(conn) {
// 旧连接读循环直接退出避免误伤新请求通道
return
}
_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
msgType, payload, errRead := conn.ReadMessage()
if errRead != nil {
sess.activeMu.Lock()
ch := sess.activeCh
done := sess.activeDone
sess.activeMu.Unlock()
// 在同一临界区做归属校验和通道快照避免检查后竞态
ch, done, current := sess.activeSnapshotForCurrentConn(conn)
if !current {
// 旧连接读错时不触碰当前活跃通道
return
}
Comment on lines +1222 to +1225

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Notify active reader when old conn is force-closed

Do not return immediately on !current here, because closeExecutionSession sets sess.conn = nil before closing the socket; when that close unblocks ReadMessage, this branch exits without sending an error to the active channel. In that case readCodexWebsocketMessage keeps waiting on readCh for the in-flight request until its context is canceled, so explicit session shutdown (for example CloseExecutionSession/executor replacement paths) can hang requests instead of failing fast.

Useful? React with 👍 / 👎.

if ch != nil {
select {
case ch <- codexWebsocketRead{conn: conn, err: errRead}:
case <-done:
default:
trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: errRead})
if sess.clearActiveForCurrentConn(conn, ch) {
tryCloseCodexWebsocketRead(ch)
}
sess.clearActive(ch)
close(ch)
}
e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead)
return
Expand All @@ -1137,29 +1250,29 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession,
if msgType != websocket.TextMessage {
if msgType == websocket.BinaryMessage {
errBinary := fmt.Errorf("codex websockets executor: unexpected binary message")
sess.activeMu.Lock()
ch := sess.activeCh
done := sess.activeDone
sess.activeMu.Unlock()
// 在同一临界区做归属校验和通道快照避免检查后竞态
ch, done, current := sess.activeSnapshotForCurrentConn(conn)
if !current {
// 旧连接二进制异常时不触碰当前活跃通道
return
}
if ch != nil {
select {
case ch <- codexWebsocketRead{conn: conn, err: errBinary}:
case <-done:
default:
trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: errBinary})
if sess.clearActiveForCurrentConn(conn, ch) {
tryCloseCodexWebsocketRead(ch)
}
sess.clearActive(ch)
close(ch)
}
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary)
return
}
continue
}

sess.activeMu.Lock()
ch := sess.activeCh
done := sess.activeDone
sess.activeMu.Unlock()
// 在同一临界区做归属校验和通道快照避免检查后竞态
ch, done, current := sess.activeSnapshotForCurrentConn(conn)
if !current {
// 旧连接消息不再分发给新连接请求
return
}
if ch == nil {
continue
}
Expand Down Expand Up @@ -1257,6 +1370,9 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
sessionID := sess.sessionID
sess.connMu.Unlock()

// 会话显式关闭时主动唤醒活跃请求避免 readCh 悬挂
sess.failActiveForSessionClose(conn, fmt.Errorf("codex websockets executor: execution session closed"))

if conn == nil {
return
}
Expand Down
88 changes: 88 additions & 0 deletions internal/runtime/executor/codex_websockets_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
Expand Down Expand Up @@ -201,3 +204,88 @@ func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
t.Fatal("expected websocket proxy function to be nil for direct mode")
}
}

func TestReadCodexWebsocketMessageReturnsWhenReadChannelClosed(t *testing.T) {
t.Parallel()

sess := &codexWebsocketSession{}
conn := &websocket.Conn{}
readCh := make(chan codexWebsocketRead)
close(readCh)

_, _, err := readCodexWebsocketMessage(context.Background(), sess, conn, readCh)
if err == nil {
t.Fatal("expected error when session read channel is closed")
}
if !strings.Contains(err.Error(), "session read channel closed") {
t.Fatalf("error = %v, want contains session read channel closed", err)
}
}

// TestCloseExecutionSessionUnblocksActiveRead verifies that a goroutine
// waiting in readCodexWebsocketMessage is released with an error shortly
// after CloseExecutionSession is called for its session.
//
// A real websocket pair is created through an httptest server. The
// server-side connection is installed as the session's conn and readerConn,
// a buffered read channel is registered as the active channel, and the read
// must fail with either the "execution session closed" or the
// "session read channel closed" error within three seconds.
func TestCloseExecutionSessionUnblocksActiveRead(t *testing.T) {
	t.Parallel()

	const sessionKey = "session-close"

	// The handler hands the upgraded connection to the test and then blocks
	// in ReadMessage until that read returns.
	accepted := make(chan *websocket.Conn, 1)
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	handler := func(w http.ResponseWriter, r *http.Request) {
		upgraded, errUpgrade := wsUpgrader.Upgrade(w, r, nil)
		if errUpgrade != nil {
			return
		}
		accepted <- upgraded
		_, _, _ = upgraded.ReadMessage()
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	clientConn, _, errDial := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(server.URL, "http"), nil)
	if errDial != nil {
		t.Fatalf("dial websocket: %v", errDial)
	}
	defer func() { _ = clientConn.Close() }()

	var serverConn *websocket.Conn
	select {
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for server websocket connection")
	case serverConn = <-accepted:
	}

	sess := &codexWebsocketSession{
		sessionID:  sessionKey,
		conn:       serverConn,
		readerConn: serverConn,
	}
	readCh := make(chan codexWebsocketRead, 4)
	sess.setActive(readCh)

	executor := &CodexWebsocketsExecutor{
		CodexExecutor: &CodexExecutor{},
		sessions:      map[string]*codexWebsocketSession{sessionKey: sess},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// Park a reader on the active channel; its result is reported through a
	// buffered channel so the goroutine never blocks on the hand-off.
	readErrCh := make(chan error, 1)
	go func() {
		_, _, errRead := readCodexWebsocketMessage(ctx, sess, serverConn, readCh)
		readErrCh <- errRead
	}()

	executor.CloseExecutionSession(sessionKey)

	var errRead error
	select {
	case <-time.After(3 * time.Second):
		t.Fatal("read did not fail fast after closeExecutionSession")
	case errRead = <-readErrCh:
	}
	if errRead == nil {
		t.Fatal("expected read error after closing execution session")
	}
	msg := errRead.Error()
	sessionClosed := strings.Contains(msg, "execution session closed")
	channelClosed := strings.Contains(msg, "session read channel closed")
	if !sessionClosed && !channelClosed {
		t.Fatalf("error = %v, want fast-fail error from session close path", errRead)
	}
}
Loading
Loading