diff --git a/main.go b/main.go index c73054a..97de070 100644 --- a/main.go +++ b/main.go @@ -67,6 +67,12 @@ const ( stateDirPerms = 0700 ) +// Dialer abstracts the remote connection mechanism. +// tsnet.Server satisfies this interface without an adapter. +type Dialer interface { + Dial(ctx context.Context, network, addr string) (net.Conn, error) +} + // Config holds the bridge configuration. type Config struct { LocalAddr string @@ -623,7 +629,7 @@ func printBanner(cfg Config) { fmt.Println() } -func acceptLoop(listener net.Listener, server *tsnet.Server, cfg Config) error { +func acceptLoop(listener net.Listener, dialer Dialer, cfg Config) error { backoff := backoffMin for { @@ -654,7 +660,7 @@ func acceptLoop(listener net.Listener, server *tsnet.Server, cfg Config) error { continue } - go handleConn(conn, server, cfg) + go handleConn(conn, dialer, cfg) } } @@ -666,7 +672,7 @@ var bufferPool = sync.Pool{ }, } -func handleConn(client net.Conn, server *tsnet.Server, cfg Config) { +func handleConn(client net.Conn, dialer Dialer, cfg Config) { // Track metrics atomic.AddInt64(&metrics.ActiveConnections, 1) atomic.AddInt64(&metrics.TotalConnections, 1) @@ -689,7 +695,7 @@ func handleConn(client net.Conn, server *tsnet.Server, cfg Config) { ctx, cancel := context.WithTimeout(context.Background(), cfg.ConnectTimeout) defer cancel() - remote, err := server.Dial(ctx, "tcp", cfg.Target) + remote, err := dialer.Dial(ctx, "tcp", cfg.Target) if err != nil { atomic.AddInt64(&metrics.TotalErrors, 1) logger.Error("dial failed", "client", addr, "target", cfg.Target, "error", err) @@ -742,8 +748,14 @@ func proxyConnections(client, remote net.Conn, addr string) (tx, rx int64) { } } - go copyConn(client, remote, "rx", &rx) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + copyConn(client, remote, "rx", &rx) + }() copyConn(remote, client, "tx", &tx) + wg.Wait() return tx, rx } @@ -780,5 +792,8 @@ func isExpectedCloseError(err error) bool { if strings.Contains(errStr, "forcibly closed by the remote host") { return true } + if strings.Contains(errStr, "closed pipe") { + return true + } return false } diff --git a/main_integration_test.go b/main_integration_test.go index bd5838a..49aaa40 100644 --- a/main_integration_test.go +++ b/main_integration_test.go @@ -350,6 +350,7 @@ func TestIsExpectedCloseError(t *testing.T) { {"closed network", errors.New("use of closed network connection"), true}, {"connection reset", errors.New("connection reset by peer"), true}, {"windows wsarecv forced close", errors.New("wsarecv: An existing connection was forcibly closed by the remote host"), true}, + {"closed pipe", errors.New("io: read/write on closed pipe"), true}, } for _, tt := range tests { diff --git a/main_test.go b/main_test.go index 403cb41..20c5f5e 100644 --- a/main_test.go +++ b/main_test.go @@ -4,13 +4,205 @@ import ( "context" "errors" "fmt" + "io" "log/slog" + "net" "os" "strings" + "sync/atomic" "testing" "time" ) +// mockDialer implements Dialer for testing without tsnet. +type mockDialer struct { + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +} + +func (m *mockDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { + return m.dialFunc(ctx, network, addr) +} + +// TestDialerInterfaceSatisfaction verifies that mockDialer satisfies the Dialer interface. +// This is a compile-time check: if Dialer doesn't exist or has a different signature, this fails. +var _ Dialer = (*mockDialer)(nil) + +func TestHandleConnWithDialer(t *testing.T) { + initLogger(Config{LogFormat: "text"}) + + tests := []struct { + name string + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + wantErrors int64 + wantTotalConn int64 + }{ + { + name: "successful proxy via dialer", + dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + // Return a pipe that immediately closes (simulates short-lived connection) + server, client := net.Pipe() + go func() { + // Echo one read then close + buf := make([]byte, 1024) + n, _ := server.Read(buf) + if n > 0 { + _, _ = server.Write(buf[:n]) + } + server.Close() + }() + return client, nil + }, + wantErrors: 0, + wantTotalConn: 1, + }, + { + name: "dial failure increments errors", + dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("connection refused") + }, + wantErrors: 1, + wantTotalConn: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset metrics + oldMetrics := metrics + metrics = Metrics{} + defer func() { metrics = oldMetrics }() + + dialer := &mockDialer{dialFunc: tt.dialFunc} + cfg := Config{ + Target: "100.64.0.1:3389", + ConnectTimeout: 5 * time.Second, + } + + // Create a client connection via pipe + clientConn, proxyConn := net.Pipe() + defer clientConn.Close() + + // Run handleConn in goroutine (it blocks until proxy finishes) + done := make(chan struct{}) + go func() { + handleConn(proxyConn, dialer, cfg) + close(done) + }() + + if tt.wantErrors == 0 { + // Send data through the proxy + _, _ = clientConn.Write([]byte("HELLO")) + + buf := make([]byte, 1024) + _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := clientConn.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("read from proxy failed: %v", err) + } + if n > 0 && string(buf[:n]) != "HELLO" { + t.Errorf("expected echo HELLO, got %q", buf[:n]) + } + } + + // Close client side to let handleConn finish + clientConn.Close() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("handleConn did not finish in time") + } + + gotErrors := atomic.LoadInt64(&metrics.TotalErrors) + if gotErrors != tt.wantErrors { + t.Errorf("TotalErrors = %d, want %d", gotErrors, tt.wantErrors) + } + + gotTotal := atomic.LoadInt64(&metrics.TotalConnections) + if gotTotal != tt.wantTotalConn { + t.Errorf("TotalConnections = %d, want %d", gotTotal, tt.wantTotalConn) + } + }) + } +} + +func TestAcceptLoopWithDialer(t *testing.T) { + initLogger(Config{LogFormat: "text"}) + + // Snapshot metrics before test to check delta after + connsBefore := atomic.LoadInt64(&metrics.TotalConnections) + + // Mock dialer that echoes data + dialer := &mockDialer{ + dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + server, client := net.Pipe() + go func() { + defer server.Close() + buf := make([]byte, 1024) + n, _ := server.Read(buf) + if n > 0 { + _, _ = server.Write(buf[:n]) + } + }() + return client, nil + }, + } + + cfg := Config{ + Target: "100.64.0.1:3389", + ConnectTimeout: 5 * time.Second, + MaxConnections: 1000, + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + + // Run accept loop in background + loopDone := make(chan error, 1) + go func() { + loopDone <- acceptLoop(listener, dialer, cfg) + }() + + // Connect a client through the accept loop + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatalf("dial failed: %v", err) + } + + _, _ = conn.Write([]byte("TEST")) + + buf := make([]byte, 1024) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("read failed: %v", err) + } + if string(buf[:n]) != "TEST" { + t.Errorf("expected TEST, got %q", buf[:n]) + } + conn.Close() + + // Close listener to stop accept loop + listener.Close() + + select { + case err := <-loopDone: + if err != nil { + t.Errorf("acceptLoop returned error: %v", err) + } + case <-time.After(3 * time.Second): + t.Fatal("acceptLoop did not stop") + } + + // Check that at least one connection was handled (use atomic reads, no struct reset) + connsAfter := atomic.LoadInt64(&metrics.TotalConnections) + if connsAfter <= connsBefore { + t.Errorf("expected TotalConnections to increase, before=%d after=%d", connsBefore, connsAfter) + } +} + func TestLoadConfig(t *testing.T) { tests := []struct { name string