diff --git a/internal/auth/claude/proxy_dialer.go b/internal/auth/claude/proxy_dialer.go new file mode 100644 index 0000000000..d0ddb3448c --- /dev/null +++ b/internal/auth/claude/proxy_dialer.go @@ -0,0 +1,154 @@ +// Package claude provides authentication functionality for Anthropic's Claude API. +// This file implements proxy dialer construction for HTTP CONNECT and SOCKS5 proxies, +// used by the utls transport to route OAuth refresh requests through a configured proxy. +package claude + +import ( + "bufio" + cryptotls "crypto/tls" + "encoding/base64" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/proxy" +) + +// dialerFunc adapts a plain function to the proxy.Dialer interface. +type dialerFunc func(network, addr string) (net.Conn, error) + +func (f dialerFunc) Dial(network, addr string) (net.Conn, error) { return f(network, addr) } + +// bufferedConn wraps a net.Conn with a bufio.Reader so that any bytes +// pre-fetched during HTTP response parsing are returned before reading +// directly from the underlying connection. +type bufferedConn struct { + net.Conn + reader *bufio.Reader +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +// buildProxyDialer creates a proxy.Dialer for the given proxy URL string. +// It supports socks5, socks5h, http, and https schemes. +// An empty URL returns proxy.Direct (no proxy). +func buildProxyDialer(rawProxyURL string) (proxy.Dialer, error) { + proxyURL := strings.TrimSpace(rawProxyURL) + if proxyURL == "" { + return proxy.Direct, nil + } + + parsedURL, errParse := url.Parse(proxyURL) + if errParse != nil { + return nil, fmt.Errorf("failed to parse proxy URL %q: %w", rawProxyURL, errParse) + } + + switch parsedURL.Scheme { + case "socks5", "socks5h": + proxyDialer, errDialer := proxy.FromURL(parsedURL, proxy.Direct) + if errDialer != nil { + return nil, fmt.Errorf("failed to create SOCKS5 dialer for %q: %w", rawProxyURL, errDialer) + } + return proxyDialer, nil + case "http", "https": + if parsedURL.Host == "" { + return nil, fmt.Errorf("failed to parse proxy URL %q: missing host", rawProxyURL) + } + if parsedURL.Port() == "" { + defaultPort := "80" + if parsedURL.Scheme == "https" { + defaultPort = "443" + } + parsedURL.Host = net.JoinHostPort(parsedURL.Hostname(), defaultPort) + } + proxyURLCopy := *parsedURL + return dialerFunc(func(network, addr string) (net.Conn, error) { + return dialHTTPConnectProxy(&proxyURLCopy, network, addr) + }), nil + default: + return nil, fmt.Errorf("failed to create proxy dialer for %q: unsupported scheme %q", rawProxyURL, parsedURL.Scheme) + } +} + +// dialHTTPConnectProxy establishes a TCP connection through an HTTP(S) proxy +// using the CONNECT method, returning the tunneled connection. +func dialHTTPConnectProxy(proxyURL *url.URL, network, addr string) (net.Conn, error) { + if network != "tcp" { + return nil, fmt.Errorf("failed to dial via HTTP proxy: CONNECT only supports tcp, got %q", network) + } + + proxyConn, errDial := net.Dial("tcp", proxyURL.Host) + if errDial != nil { + return nil, fmt.Errorf("failed to dial proxy %q: %w", proxyURL.Host, errDial) + } + + if proxyURL.Scheme == "https" { + tlsConn := cryptotls.Client(proxyConn, &cryptotls.Config{ServerName: proxyURL.Hostname()}) + if errHandshake := tlsConn.Handshake(); errHandshake != nil { + _ = proxyConn.Close() + return nil, fmt.Errorf("failed to TLS-handshake with proxy %q: %w", proxyURL.Host, errHandshake) + } + proxyConn = tlsConn + } + + tunneledConn, errConnect := establishHTTPConnectTunnel(proxyConn, proxyURL, addr) + if errConnect != nil { + _ = proxyConn.Close() + return nil, errConnect + } + + return tunneledConn, nil +} + +// establishHTTPConnectTunnel sends an HTTP CONNECT request through proxyConn +// and returns the tunneled connection on success. +func establishHTTPConnectTunnel(proxyConn net.Conn, proxyURL *url.URL, addr string) (net.Conn, error) { + var reqBuf strings.Builder + reqBuf.WriteString("CONNECT ") + reqBuf.WriteString(addr) + reqBuf.WriteString(" HTTP/1.1\r\nHost: ") + reqBuf.WriteString(addr) + reqBuf.WriteString("\r\n") + + if proxyURL.User != nil { + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + credentials := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + reqBuf.WriteString("Proxy-Authorization: Basic ") + reqBuf.WriteString(credentials) + reqBuf.WriteString("\r\n") + } + + reqBuf.WriteString("\r\n") + + if _, errWrite := io.WriteString(proxyConn, reqBuf.String()); errWrite != nil { + return nil, fmt.Errorf("failed to send CONNECT to proxy %q: %w", proxyURL.Host, errWrite) + } + + reader := bufio.NewReader(proxyConn) + resp, errResponse := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect}) + if errResponse != nil { + return nil, fmt.Errorf("failed to read CONNECT response from proxy %q: %w", proxyURL.Host, errResponse) + } + + if resp.StatusCode != http.StatusOK { + body, errBody := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) + _ = resp.Body.Close() + if errBody != nil { + return nil, fmt.Errorf("proxy CONNECT to %q failed with status %s (failed to read error body: %v)", proxyURL.Host, resp.Status, errBody) + } + message := strings.TrimSpace(string(body)) + if message != "" { + return nil, fmt.Errorf("proxy CONNECT to %q failed with status %s: %s", proxyURL.Host, resp.Status, message) + } + return nil, fmt.Errorf("proxy CONNECT to %q failed with status %s", proxyURL.Host, resp.Status) + } + // resp.Body is intentionally not closed here: for CONNECT 200 Go sets it + // to http.NoBody, and the tunnel data follows in `reader`. + return &bufferedConn{Conn: proxyConn, reader: reader}, nil +} diff --git a/internal/auth/claude/proxy_dialer_test.go b/internal/auth/claude/proxy_dialer_test.go new file mode 100644 index 0000000000..ee2f4b5be9 --- /dev/null +++ b/internal/auth/claude/proxy_dialer_test.go @@ -0,0 +1,325 @@ +package claude + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" + "strings" + "testing" + "time" + + "golang.org/x/net/proxy" +) + +func TestBuildProxyDialerEmptyUsesDirect(t *testing.T) { + t.Parallel() + + dialer, err := buildProxyDialer("") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dialer != proxy.Direct { + t.Fatalf("expected proxy.Direct, got %T", dialer) + } +} + +func TestBuildProxyDialerWhitespaceUsesDirect(t *testing.T) { + t.Parallel() + + dialer, err := buildProxyDialer(" ") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dialer != proxy.Direct { + t.Fatalf("expected proxy.Direct, got %T", dialer) + } +} + +func TestBuildProxyDialerRejectsUnsupportedScheme(t *testing.T) { + t.Parallel() + + _, err := buildProxyDialer("ftp://proxy.example.com:21") + if err == nil { + t.Fatal("expected error for unsupported scheme, got nil") + } + if !strings.Contains(err.Error(), "unsupported scheme") { + t.Fatalf("error = %q, want substring 'unsupported scheme'", err) + } +} + +func TestBuildProxyDialerRejectsMissingHost(t *testing.T) { + t.Parallel() + + _, err := buildProxyDialer("http://") + if err == nil { + t.Fatal("expected error for missing host, got nil") + } +} + +func TestBuildProxyDialerAcceptsSocks5(t *testing.T) { + t.Parallel() + + dialer, err := buildProxyDialer("socks5://proxy.example.com:1080") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dialer == nil { + t.Fatal("expected dialer, got nil") + } + if dialer == proxy.Direct { + t.Fatal("expected proxy dialer, got proxy.Direct") + } +} + +func TestBuildProxyDialerAcceptsHTTPSProxy(t *testing.T) { + t.Parallel() + + dialer, err := buildProxyDialer("https://proxy.example.com:8443") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dialer == nil { + t.Fatal("expected dialer, got nil") + } +} + +func TestBuildProxyDialerDefaultPort(t *testing.T) { + t.Parallel() + + dialer, err := buildProxyDialer("http://proxy.example.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dialer == nil { + t.Fatal("expected dialer, got nil") + } +} + +func TestHTTPProxyConnectTunnel(t *testing.T) { + t.Parallel() + + targetListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen target: %v", err) + } + defer func() { _ = targetListener.Close() }() + + targetPayload := make(chan string, 1) + targetErr := make(chan error, 1) + go func() { + conn, errAccept := targetListener.Accept() + if errAccept != nil { + targetErr <- errAccept + return + } + defer func() { _ = conn.Close() }() + buf := make([]byte, 4) + if _, errRead := io.ReadFull(conn, buf); errRead != nil { + targetErr <- errRead + return + } + targetPayload <- string(buf) + _, _ = conn.Write([]byte("pong")) + targetErr <- nil + }() + + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen proxy: %v", err) + } + defer func() { _ = proxyListener.Close() }() + + connectLine := make(chan string, 1) + proxyErr := make(chan error, 1) + go serveConnectProxy(t, proxyListener, targetListener.Addr().String(), connectLine, proxyErr) + + dialer, err := buildProxyDialer("http://" + proxyListener.Addr().String()) + if err != nil { + t.Fatalf("buildProxyDialer: %v", err) + } + + conn, err := dialer.Dial("tcp", targetListener.Addr().String()) + if err != nil { + t.Fatalf("dial through proxy: %v", err) + } + defer func() { _ = conn.Close() }() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + + if _, err := conn.Write([]byte("ping")); err != nil { + t.Fatalf("write through tunnel: %v", err) + } + + resp := make([]byte, 4) + if _, err := io.ReadFull(conn, resp); err != nil { + t.Fatalf("read through tunnel: %v", err) + } + if string(resp) != "pong" { + t.Fatalf("response = %q, want %q", string(resp), "pong") + } + + wantLine := "CONNECT " + targetListener.Addr().String() + " HTTP/1.1" + if got := <-connectLine; got != wantLine { + t.Fatalf("CONNECT request = %q, want %q", got, wantLine) + } + if got := <-targetPayload; got != "ping" { + t.Fatalf("target payload = %q, want %q", got, "ping") + } + if err := <-proxyErr; err != nil { + t.Fatalf("proxy server: %v", err) + } + if err := <-targetErr; err != nil { + t.Fatalf("target server: %v", err) + } +} + +func TestHTTPProxyConnectTunnelWithAuth(t *testing.T) { + t.Parallel() + + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer func() { _ = proxyListener.Close() }() + + authHeader := make(chan string, 1) + proxyErr := make(chan error, 1) + go func() { + conn, errAccept := proxyListener.Accept() + if errAccept != nil { + proxyErr <- errAccept + return + } + defer func() { _ = conn.Close() }() + reader := bufio.NewReader(conn) + req, errRequest := http.ReadRequest(reader) + if errRequest != nil { + proxyErr <- errRequest + return + } + _ = req.Body.Close() + authHeader <- req.Header.Get("Proxy-Authorization") + _, _ = io.WriteString(conn, "HTTP/1.1 200 OK\r\n\r\n") + proxyErr <- nil + }() + + dialer, err := buildProxyDialer("http://user:pass@" + proxyListener.Addr().String()) + if err != nil { + t.Fatalf("buildProxyDialer: %v", err) + } + + conn, err := dialer.Dial("tcp", "example.com:443") + if err != nil { + t.Fatalf("dial: %v", err) + } + _ = conn.Close() + + got := <-authHeader + if got == "" { + t.Fatal("expected Proxy-Authorization header, got empty") + } + if !strings.HasPrefix(got, "Basic ") { + t.Fatalf("auth = %q, want Basic prefix", got) + } + if err := <-proxyErr; err != nil { + t.Fatalf("proxy: %v", err) + } +} + +func TestHTTPProxyConnectRejectsNon200(t *testing.T) { + t.Parallel() + + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer func() { _ = proxyListener.Close() }() + + go func() { + conn, errAccept := proxyListener.Accept() + if errAccept != nil { + return + } + defer func() { _ = conn.Close() }() + reader := bufio.NewReader(conn) + req, errRequest := http.ReadRequest(reader) + if errRequest != nil { + return + } + _ = req.Body.Close() + _, _ = io.WriteString(conn, "HTTP/1.1 407 Proxy Authentication Required\r\nContent-Length: 12\r\n\r\nUnauthorized") + }() + + dialer, err := buildProxyDialer("http://" + proxyListener.Addr().String()) + if err != nil { + t.Fatalf("buildProxyDialer: %v", err) + } + + _, err = dialer.Dial("tcp", "example.com:443") + if err == nil { + t.Fatal("expected error for 407, got nil") + } + if !strings.Contains(err.Error(), "407") { + t.Fatalf("error = %q, want substring '407'", err) + } +} + +func serveConnectProxy(t *testing.T, proxyListener net.Listener, targetAddr string, connectLine chan<- string, errCh chan<- error) { + t.Helper() + conn, err := proxyListener.Accept() + if err != nil { + errCh <- err + return + } + defer func() { _ = conn.Close() }() + + reader := bufio.NewReader(conn) + req, errRequest := http.ReadRequest(reader) + if errRequest != nil { + errCh <- errRequest + return + } + _ = req.Body.Close() + + connectLine <- req.Method + " " + req.RequestURI + " " + req.Proto + + if req.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected method %q", req.Method) + return + } + + targetConn, err := net.Dial("tcp", targetAddr) + if err != nil { + errCh <- err + return + } + defer func() { _ = targetConn.Close() }() + + if _, err := io.WriteString(conn, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil { + errCh <- err + return + } + + buf := make([]byte, 4) + if _, err := io.ReadFull(reader, buf); err != nil { + errCh <- err + return + } + if _, err := targetConn.Write(buf); err != nil { + errCh <- err + return + } + + resp := make([]byte, 4) + if _, err := io.ReadFull(targetConn, resp); err != nil { + errCh <- err + return + } + if _, err := conn.Write(resp); err != nil { + errCh <- err + return + } + + errCh <- nil +}