Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
448 changes: 0 additions & 448 deletions coverage-amp-target.out

This file was deleted.

450 changes: 0 additions & 450 deletions coverage-core-auth-proxy.out

This file was deleted.

5,681 changes: 0 additions & 5,681 deletions coverage-critical.out

This file was deleted.

448 changes: 0 additions & 448 deletions coverage.audit.out

This file was deleted.

37 changes: 27 additions & 10 deletions internal/api/handlers/management/auth_files.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ const (

type callbackForwarder struct {
provider string
target string
server *http.Server
done chan struct{}
refs int
}

var (
Expand Down Expand Up @@ -121,14 +123,18 @@ func isAddressAlreadyInUse(err error) bool {

func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) {
callbackForwardersMu.Lock()
prev := callbackForwarders[port]
if prev != nil {
if existing := callbackForwarders[port]; existing != nil {
if existing.provider == provider && existing.target == targetBase {
existing.refs++
callbackForwardersMu.Unlock()
log.Infof("callback forwarder for %s reusing %s (refs=%d)", provider, fmt.Sprintf("127.0.0.1:%d", port), existing.refs)
return existing, nil
}
delete(callbackForwarders, port)
}
callbackForwardersMu.Unlock()

if prev != nil {
stopForwarderInstance(port, prev)
callbackForwardersMu.Unlock()
stopForwarderInstance(port, existing)
} else {
callbackForwardersMu.Unlock()
}

addr := fmt.Sprintf("127.0.0.1:%d", port)
Expand Down Expand Up @@ -166,8 +172,10 @@ func startCallbackForwarder(port int, provider, targetBase string) (*callbackFor

forwarder := &callbackForwarder{
provider: provider,
target: targetBase,
server: srv,
done: done,
refs: 1,
}

callbackForwardersMu.Lock()
Expand All @@ -184,12 +192,21 @@ func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
return
}
callbackForwardersMu.Lock()
if current := callbackForwarders[port]; current == forwarder {
delete(callbackForwarders, port)
current := callbackForwarders[port]
if current != forwarder {
callbackForwardersMu.Unlock()
return
}
if current.refs > 1 {
current.refs--
callbackForwardersMu.Unlock()
log.Infof("callback forwarder for %s on port %d retained for shared sessions (refs=%d)", current.provider, port, current.refs)
return
}
delete(callbackForwarders, port)
callbackForwardersMu.Unlock()

stopForwarderInstance(port, forwarder)
stopForwarderInstance(port, current)
}

func stopForwarderInstance(port int, forwarder *callbackForwarder) {
Expand Down
81 changes: 81 additions & 0 deletions internal/api/handlers/management/auth_files_codex_webui_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
Expand Down Expand Up @@ -52,3 +53,83 @@ func TestRequestCodexToken_WebUIPortInUseReturnsConflict(t *testing.T) {
t.Fatalf("error message = %q, want contains callback port", body.Error)
}
}

func TestCallbackForwarderSharedLifecycleKeepsListenerUntilLastRelease(t *testing.T) {
port := reserveFreePort(t)
target := "http://127.0.0.1:28317/codex/callback"

forwarderA, err := startCallbackForwarder(port, "codex", target)
if err != nil {
t.Fatalf("startCallbackForwarder(first): %v", err)
}
t.Cleanup(func() {
stopCallbackForwarderInstance(port, forwarderA)
})

forwarderB, err := startCallbackForwarder(port, "codex", target)
if err != nil {
t.Fatalf("startCallbackForwarder(second): %v", err)
}
if forwarderA != forwarderB {
t.Fatalf("expected shared forwarder instance, got %p and %p", forwarderA, forwarderB)
}

assertForwarderResponds(t, port)

stopCallbackForwarderInstance(port, forwarderA)
assertForwarderResponds(t, port)

stopCallbackForwarderInstance(port, forwarderB)
assertForwarderStops(t, port)
}

func reserveFreePort(t *testing.T) int {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen(reserveFreePort): %v", err)
}
defer listener.Close()

addr, ok := listener.Addr().(*net.TCPAddr)
if !ok {
t.Fatalf("listener.Addr() = %T, want *net.TCPAddr", listener.Addr())
}
return addr.Port
}

func assertForwarderResponds(t *testing.T, port int) {
t.Helper()

client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Timeout: 2 * time.Second,
}
resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/auth/callback?state=test", port))
if err != nil {
t.Fatalf("GET forwarder: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusFound {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusFound)
}
}

func assertForwarderStops(t *testing.T, port int) {
t.Helper()

deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 200*time.Millisecond)
if err != nil {
return
}
_ = conn.Close()
time.Sleep(50 * time.Millisecond)
}

t.Fatalf("port %d still accepts connections after releasing the last forwarder reference", port)
}
Binary file removed mock_upstream
Binary file not shown.
Loading