diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ba36c013bf5..69f44066d8b 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -217,7 +217,7 @@ jobs: - arch: "386" raceFlag: "" - arch: "amd64" - raceFlag: "-race" + raceFlag: "-race -v" runs-on: ubuntu-22.04 steps: - name: Install Go @@ -258,6 +258,7 @@ jobs: run: | CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ go test ${{ matrix.raceFlag }} \ + -tags devcert \ -exec 'sudo' \ -timeout 10m ./relay/... ./shared/relay/... diff --git a/client/internal/connect.go b/client/internal/connect.go index bb7c2b38b0e..29cb036f445 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -253,7 +253,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return wrapErr(err) } - relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) + relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), &relayClient.ManagerOpts{ + MTU: engineConfig.MTU, + }) c.statusRecorder.SetRelayMgr(relayManager) if len(relayURLs) > 0 { if token != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d15a07f9d27..7fc1856b032 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -14,7 +14,6 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" - "github.com/netbirdio/netbird/client/internal/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,7 +24,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -227,7 +229,7 @@ func TestEngine_SSH(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU}) engine := NewEngine( ctx, cancel, &signal.MockClient{}, @@ -373,7 +375,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU}) engine := NewEngine( ctx, cancel, &signal.MockClient{}, @@ -600,7 +602,7 @@ func TestEngine_Sync(t *testing.T) { } return nil } - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU}) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ WgIfaceName: "utun103", WgAddr: "100.64.0.1/24", @@ -765,7 +767,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n) - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU}) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, @@ -967,7 +969,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n) - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU}) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, @@ -1499,7 +1501,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin MTU: iface.DefaultMTU, } - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU}) e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil e.ctx = ctx return e, err diff --git a/relay/server/relay.go b/relay/server/relay.go index d866849379c..771eea4fdd5 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -132,6 +132,7 @@ func (r *Relay) Accept(conn net.Conn) { storeTime := time.Now() if isReconnection := r.store.AddPeer(peer); isReconnection { r.metrics.RecordPeerReconnection() + r.notifier.PeerWentOffline(peer.ID()) } r.notifier.PeerCameOnline(peer.ID()) diff --git a/shared/relay/client/client_test.go b/shared/relay/client/client_test.go index 8fe5f04f444..09caac6f478 100644 --- a/shared/relay/client/client_test.go +++ b/shared/relay/client/client_test.go @@ -11,11 +11,11 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay/auth/allow" "github.com/netbirdio/netbird/shared/relay/auth/hmac" + "github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/util" - - "github.com/netbirdio/netbird/relay/server" ) var ( @@ -312,7 +312,7 @@ func TestBindToUnavailabePeer(t *testing.T) { clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { - t.Errorf("failed to connect to server: %s", err) + t.Fatalf("failed to connect to server: %s", err) } _, err = clientAlice.OpenConn(ctx, "bob") if err == nil { @@ -364,7 +364,7 @@ func TestBindReconnect(t *testing.T) { clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { - t.Errorf("failed to connect to server: %s", err) + t.Fatalf("failed to connect to server: %s", err) } _, err = clientAlice.OpenConn(ctx, "bob") @@ -374,7 +374,7 @@ func TestBindReconnect(t *testing.T) { chBob, err := clientBob.OpenConn(ctx, "alice") if err != nil { - t.Errorf("failed to bind channel: %s", err) + t.Fatalf("failed to bind channel: %s", err) } log.Infof("closing client Alice") @@ -386,12 +386,12 @@ func TestBindReconnect(t *testing.T) { clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { - t.Errorf("failed to connect to server: %s", err) + t.Fatalf("failed to connect to server: %s", err) } chAlice, err := clientAlice.OpenConn(ctx, "bob") if err != nil { - t.Errorf("failed to bind channel: %s", err) + t.Fatalf("failed to bind channel: %s", err) } testString := "hello alice, I am bob" @@ -402,7 +402,7 @@ func TestBindReconnect(t *testing.T) { chBob, err = clientBob.OpenConn(ctx, "alice") if err != nil { - t.Errorf("failed to bind channel: %s", err) + t.Fatalf("failed to bind channel: %s", err) } _, err = chBob.Write([]byte(testString)) @@ -427,6 +427,105 @@ func TestBindReconnect(t *testing.T) { } } +func TestBindReconnectRace(t *testing.T) { + ctx := context.Background() + + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(serverCfg) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + log.Infof("closing server") + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + err = clientBob.Connect(ctx) + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer clientBob.Close() + + // Run the reconnection scenario multiple times to expose the race + failures := 0 + iterations := 1000 + + for i := 0; i < iterations; i++ { + log.Infof("Iteration %d/%d", i+1, iterations) + + // Alice connects + clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + err = clientAlice.Connect(ctx) + if err != nil { + t.Fatalf("iteration %d: failed to connect alice: %s", i, err) + } + + // Bob opens connection to Alice + _, err = clientBob.OpenConn(ctx, "alice") + if err != nil { + t.Fatalf("iteration %d: failed to open conn from bob: %s", i, err) + } + + // Close Alice immediately + err = clientAlice.Close() + if err != nil { + t.Errorf("iteration %d: failed to close alice: %s", i, err) + } + + // Reconnect Alice immediately (this is where the race occurs) + clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + err = clientAlice.Connect(ctx) + if err != nil { + t.Fatalf("iteration %d: failed to reconnect alice: %s", i, err) + } + + // Bob tries to open a new connection to the reconnected Alice + // Without the fix, this will sometimes fail with "connection already exists" + // because Bob still has the old connection in its map + _, err = clientBob.OpenConn(ctx, "alice") + if err != nil { + log.Errorf("iteration %d: RACE DETECTED - failed to open new conn after reconnect: %s", i, err) + failures++ + } + + // Clean up + clientAlice.Close() + + // Close Bob's connection to Alice to prepare for next iteration + clientBob.mu.Lock() + aliceID := messages.HashID("alice") + if container, ok := clientBob.conns[aliceID]; ok { + container.close() + delete(clientBob.conns, aliceID) + } + clientBob.mu.Unlock() + } + + if failures > 0 { + t.Errorf("Race condition detected in %d out of %d iterations (%.1f%%)", + failures, iterations, float64(failures)/float64(iterations)*100) + } else { + log.Infof("No race detected in %d iterations (fix is working or race didn't trigger)", iterations) + } +} + func TestCloseConn(t *testing.T) { ctx := context.Background() @@ -459,18 +558,18 @@ func TestCloseConn(t *testing.T) { bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) err = bob.Connect(ctx) if err != nil { - t.Errorf("failed to connect to server: %s", err) + t.Fatalf("failed to connect to server: %s", err) } clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { - t.Errorf("failed to connect to server: %s", err) + t.Fatalf("failed to connect to server: %s", err) } conn, err := clientAlice.OpenConn(ctx, "bob") if err != nil { - t.Errorf("failed to bind channel: %s", err) + t.Fatalf("failed to bind channel: %s", err) } log.Infof("closing connection") @@ -532,7 +631,7 @@ func TestCloseRelayConn(t *testing.T) { conn, err := clientAlice.OpenConn(ctx, "bob") if err != nil { - t.Errorf("failed to bind channel: %s", err) + t.Fatalf("failed to bind channel: %s", err) } _ = clientAlice.relayConn.Close() diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index 6220e7f6b06..8ed2a00b9a4 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -14,10 +14,15 @@ import ( relayAuth "github.com/netbirdio/netbird/shared/relay/auth/hmac" ) -var ( - relayCleanupInterval = 60 * time.Second - keepUnusedServerTime = 5 * time.Second +const ( + defaultRelayCleanupInterval = 60 * time.Second + defaultKeepUnusedServerTime = 5 * time.Second + defaultMTU = 1280 + minMTU = 1280 + maxMTU = 65535 +) +var ( ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ) @@ -64,14 +69,55 @@ type Manager struct { onReconnectedListenerFn func() listenerLock sync.Mutex - mtu uint16 + cleanupInterval time.Duration + unusedServerTime time.Duration + mtu uint16 +} + +// ManagerOpts contains optional configuration for Manager +type ManagerOpts struct { + // CleanupInterval is the interval for cleaning up unused relay connections. + // If zero, defaults to defaultRelayCleanupInterval. + CleanupInterval time.Duration + // UnusedServerTime is the time to wait before closing unused relay connections. + // If zero, defaults to defaultKeepUnusedServerTime. + UnusedServerTime time.Duration + // MTU is the maximum transmission unit for relay connections. + // If zero, defaults to defaultMTU (1280). + // Must be between minMTU (1280) and maxMTU (65535). + MTU uint16 } // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. -func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager { +// Optional parameters can be configured using ManagerOpts. Pass nil to use default values. +func NewManager(ctx context.Context, serverURLs []string, peerID string, opts *ManagerOpts) *Manager { tokenStore := &relayAuth.TokenStore{} + cleanupInterval := defaultRelayCleanupInterval + unusedServerTime := defaultKeepUnusedServerTime + mtu := uint16(defaultMTU) + + if opts != nil { + if opts.CleanupInterval > 0 { + cleanupInterval = opts.CleanupInterval + } + if opts.UnusedServerTime > 0 { + unusedServerTime = opts.UnusedServerTime + } + if opts.MTU > 0 { + if opts.MTU < minMTU { + log.Warnf("MTU %d is below minimum %d, using minimum", opts.MTU, minMTU) + mtu = minMTU + } else if opts.MTU > maxMTU { + log.Warnf("MTU %d exceeds maximum %d, using maximum", opts.MTU, maxMTU) + mtu = maxMTU + } else { + mtu = opts.MTU + } + } + } + m := &Manager{ ctx: ctx, peerID: peerID, @@ -85,6 +131,8 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), + cleanupInterval: cleanupInterval, + unusedServerTime: unusedServerTime, } m.serverPicker.ServerURLs.Store(serverURLs) m.reconnectGuard = NewGuard(m.serverPicker) @@ -334,7 +382,7 @@ func (m *Manager) isForeignServer(address string) (bool, error) { } func (m *Manager) startCleanupLoop() { - ticker := time.NewTicker(relayCleanupInterval) + ticker := time.NewTicker(m.cleanupInterval) defer ticker.Stop() for { select { @@ -359,7 +407,7 @@ func (m *Manager) cleanUpUnusedRelays() { continue } - if time.Since(rt.created) <= keepUnusedServerTime { + if time.Since(rt.created) <= m.unusedServerTime { rt.Unlock() continue } diff --git a/shared/relay/client/manager_test.go b/shared/relay/client/manager_test.go index f00b3570709..6527402344b 100644 --- a/shared/relay/client/manager_test.go +++ b/shared/relay/client/manager_test.go @@ -16,7 +16,7 @@ import ( func TestEmptyURL(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mgr := NewManager(ctx, nil, "alice", iface.DefaultMTU) + mgr := NewManager(ctx, nil, "alice", &ManagerOpts{MTU: iface.DefaultMTU}) err := mgr.Serve() if err == nil { t.Errorf("expected error, got nil") @@ -91,12 +91,12 @@ func TestForeignConn(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice", iface.DefaultMTU) + clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice", &ManagerOpts{MTU: iface.DefaultMTU}) if err := clientAlice.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU) + clientBob := NewManager(mCtx, toURL(srvCfg2), "bob", &ManagerOpts{MTU: iface.DefaultMTU}) if err := clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } @@ -198,12 +198,12 @@ func TestForeginConnClose(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU) + mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob", &ManagerOpts{MTU: iface.DefaultMTU}) if err := mgrBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - mgr := NewManager(mCtx, toURL(srvCfg1), "alice", iface.DefaultMTU) + mgr := NewManager(mCtx, toURL(srvCfg1), "alice", &ManagerOpts{MTU: iface.DefaultMTU}) err = mgr.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -221,8 +221,8 @@ func TestForeginConnClose(t *testing.T) { func TestForeignAutoClose(t *testing.T) { ctx := context.Background() - relayCleanupInterval = 1 * time.Second - keepUnusedServerTime = 2 * time.Second + testCleanupInterval := 1 * time.Second + testUnusedServerTime := 2 * time.Second srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", @@ -283,7 +283,11 @@ func TestForeignAutoClose(t *testing.T) { t.Log("connect to server 1.") mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgr := NewManager(mCtx, toURL(srvCfg1), idAlice, iface.DefaultMTU) + mgr := NewManager(mCtx, toURL(srvCfg1), idAlice, &ManagerOpts{ + MTU: iface.DefaultMTU, + CleanupInterval: testCleanupInterval, + UnusedServerTime: testUnusedServerTime, + }) err = mgr.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -310,7 +314,7 @@ func TestForeignAutoClose(t *testing.T) { } // Wait for cleanup to happen - timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second + timeout := testCleanupInterval + testUnusedServerTime + 2*time.Second t.Logf("waiting for relay cleanup: %s", timeout) select { @@ -354,13 +358,13 @@ func TestAutoReconnect(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientBob := NewManager(mCtx, toURL(srvCfg), "bob", iface.DefaultMTU) + clientBob := NewManager(mCtx, toURL(srvCfg), "bob", &ManagerOpts{MTU: iface.DefaultMTU}) err = clientBob.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU) + clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", &ManagerOpts{MTU: iface.DefaultMTU}) err = clientAlice.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -429,12 +433,12 @@ func TestNotifierDoubleAdd(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob", iface.DefaultMTU) + clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob", &ManagerOpts{MTU: iface.DefaultMTU}) if err = clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice", iface.DefaultMTU) + clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice", &ManagerOpts{MTU: iface.DefaultMTU}) if err = clientAlice.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) }