Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ const (
stateDirPerms = 0700
)

// Dialer abstracts the remote connection mechanism.
// tsnet.Server satisfies this interface without an adapter.
type Dialer interface {
Dial(ctx context.Context, network, addr string) (net.Conn, error)
}

// Config holds the bridge configuration.
type Config struct {
LocalAddr string
Expand Down Expand Up @@ -623,7 +629,7 @@ func printBanner(cfg Config) {
fmt.Println()
}

func acceptLoop(listener net.Listener, server *tsnet.Server, cfg Config) error {
func acceptLoop(listener net.Listener, dialer Dialer, cfg Config) error {
backoff := backoffMin

for {
Expand Down Expand Up @@ -654,7 +660,7 @@ func acceptLoop(listener net.Listener, server *tsnet.Server, cfg Config) error {
continue
}

go handleConn(conn, server, cfg)
go handleConn(conn, dialer, cfg)
}
}

Expand All @@ -666,7 +672,7 @@ var bufferPool = sync.Pool{
},
}

func handleConn(client net.Conn, server *tsnet.Server, cfg Config) {
func handleConn(client net.Conn, dialer Dialer, cfg Config) {
// Track metrics
atomic.AddInt64(&metrics.ActiveConnections, 1)
atomic.AddInt64(&metrics.TotalConnections, 1)
Expand All @@ -689,7 +695,7 @@ func handleConn(client net.Conn, server *tsnet.Server, cfg Config) {
ctx, cancel := context.WithTimeout(context.Background(), cfg.ConnectTimeout)
defer cancel()

remote, err := server.Dial(ctx, "tcp", cfg.Target)
remote, err := dialer.Dial(ctx, "tcp", cfg.Target)
if err != nil {
atomic.AddInt64(&metrics.TotalErrors, 1)
logger.Error("dial failed", "client", addr, "target", cfg.Target, "error", err)
Expand Down Expand Up @@ -742,8 +748,14 @@ func proxyConnections(client, remote net.Conn, addr string) (tx, rx int64) {
}
}

go copyConn(client, remote, "rx", &rx)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
copyConn(client, remote, "rx", &rx)
}()
copyConn(remote, client, "tx", &tx)
wg.Wait()

return tx, rx
}
Expand Down Expand Up @@ -780,5 +792,8 @@ func isExpectedCloseError(err error) bool {
if strings.Contains(errStr, "forcibly closed by the remote host") {
return true
}
if strings.Contains(errStr, "closed pipe") {
return true
}
return false
}
1 change: 1 addition & 0 deletions main_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ func TestIsExpectedCloseError(t *testing.T) {
{"closed network", errors.New("use of closed network connection"), true},
{"connection reset", errors.New("connection reset by peer"), true},
{"windows wsarecv forced close", errors.New("wsarecv: An existing connection was forcibly closed by the remote host"), true},
{"closed pipe", errors.New("io: read/write on closed pipe"), true},
}

for _, tt := range tests {
Expand Down
192 changes: 192 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,205 @@ import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"net"
"os"
"strings"
"sync/atomic"
"testing"
"time"
)

// mockDialer implements Dialer for testing without tsnet.
type mockDialer struct {
dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
}

func (m *mockDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
return m.dialFunc(ctx, network, addr)
}

// TestDialerInterfaceSatisfaction verifies that mockDialer satisfies the Dialer interface.
// This is a compile-time check: if Dialer doesn't exist or has a different signature, this fails.
var _ Dialer = (*mockDialer)(nil)

func TestHandleConnWithDialer(t *testing.T) {
initLogger(Config{LogFormat: "text"})

tests := []struct {
name string
dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
wantErrors int64
wantTotalConn int64
}{
{
name: "successful proxy via dialer",
dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) {
// Return a pipe that immediately closes (simulates short-lived connection)
server, client := net.Pipe()
go func() {
// Echo one read then close
buf := make([]byte, 1024)
n, _ := server.Read(buf)
if n > 0 {
_, _ = server.Write(buf[:n])
}
server.Close()
}()
return client, nil
},
wantErrors: 0,
wantTotalConn: 1,
},
{
name: "dial failure increments errors",
dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("connection refused")
},
wantErrors: 1,
wantTotalConn: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset metrics
oldMetrics := metrics
metrics = Metrics{}
defer func() { metrics = oldMetrics }()

dialer := &mockDialer{dialFunc: tt.dialFunc}
cfg := Config{
Target: "100.64.0.1:3389",
ConnectTimeout: 5 * time.Second,
}

// Create a client connection via pipe
clientConn, proxyConn := net.Pipe()
defer clientConn.Close()

// Run handleConn in goroutine (it blocks until proxy finishes)
done := make(chan struct{})
go func() {
handleConn(proxyConn, dialer, cfg)
close(done)
}()

if tt.wantErrors == 0 {
// Send data through the proxy
_, _ = clientConn.Write([]byte("HELLO"))

buf := make([]byte, 1024)
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := clientConn.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
t.Fatalf("read from proxy failed: %v", err)
}
if n > 0 && string(buf[:n]) != "HELLO" {
t.Errorf("expected echo HELLO, got %q", buf[:n])
}
}

// Close client side to let handleConn finish
clientConn.Close()

select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("handleConn did not finish in time")
}

gotErrors := atomic.LoadInt64(&metrics.TotalErrors)
if gotErrors != tt.wantErrors {
t.Errorf("TotalErrors = %d, want %d", gotErrors, tt.wantErrors)
}

gotTotal := atomic.LoadInt64(&metrics.TotalConnections)
if gotTotal != tt.wantTotalConn {
t.Errorf("TotalConnections = %d, want %d", gotTotal, tt.wantTotalConn)
}
})
}
}

func TestAcceptLoopWithDialer(t *testing.T) {
initLogger(Config{LogFormat: "text"})

// Snapshot metrics before test to check delta after
connsBefore := atomic.LoadInt64(&metrics.TotalConnections)

// Mock dialer that echoes data
dialer := &mockDialer{
dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) {
server, client := net.Pipe()
go func() {
defer server.Close()
buf := make([]byte, 1024)
n, _ := server.Read(buf)
if n > 0 {
_, _ = server.Write(buf[:n])
}
}()
return client, nil
},
}

cfg := Config{
Target: "100.64.0.1:3389",
ConnectTimeout: 5 * time.Second,
MaxConnections: 1000,
}

listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen failed: %v", err)
}

// Run accept loop in background
loopDone := make(chan error, 1)
go func() {
loopDone <- acceptLoop(listener, dialer, cfg)
}()

// Connect a client through the accept loop
conn, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
t.Fatalf("dial failed: %v", err)
}

_, _ = conn.Write([]byte("TEST"))

buf := make([]byte, 1024)
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := conn.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
t.Fatalf("read failed: %v", err)
}
if string(buf[:n]) != "TEST" {
t.Errorf("expected TEST, got %q", buf[:n])
}
conn.Close()

// Close listener to stop accept loop
listener.Close()

select {
case err := <-loopDone:
if err != nil {
t.Errorf("acceptLoop returned error: %v", err)
}
case <-time.After(3 * time.Second):
t.Fatal("acceptLoop did not stop")
}

// Check that at least one connection was handled (use atomic reads, no struct reset)
connsAfter := atomic.LoadInt64(&metrics.TotalConnections)
if connsAfter <= connsBefore {
t.Errorf("expected TotalConnections to increase, before=%d after=%d", connsBefore, connsAfter)
}
}

func TestLoadConfig(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading