diff --git a/agent.go b/agent.go
index 665c26aa..92273ae5 100644
--- a/agent.go
+++ b/agent.go
@@ -11,6 +11,7 @@ import (
 	"math"
 	"net"
 	"net/netip"
+	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -58,8 +59,10 @@ type Agent struct {
 	tieBreaker uint64
 	lite       bool
 
-	connectionState ConnectionState
-	gatheringState  GatheringState
+	connectionState  ConnectionState
+	gatheringState   GatheringState
+	gatherGeneration uint64
+	gatherEndSent    bool
 
 	mDNSMode MulticastDNSMode
 	mDNSName string
@@ -1052,28 +1055,39 @@ func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
 	a.requestConnectivityCheck()
 }
 
-func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn) error {
+func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn, gen uint64) error {
 	if err := ctx.Err(); err != nil {
 		return err
 	}
 
+	cleanupCandidate := func(reason string) {
+		if err := cand.close(); err != nil {
+			a.log.Warnf("Failed to close %s candidate: %v", reason, err)
+		}
+		if err := candidateConn.Close(); err != nil {
+			a.log.Warnf("Failed to close %s candidate connection: %v", reason, err)
+		}
+	}
+
 	return a.loop.Run(ctx, func(context.Context) {
+		if a.gatherGeneration != gen {
+			a.log.Debugf("Ignoring candidate from different gather generation (a: %d c: %d)", a.gatherGeneration, gen)
+			cleanupCandidate("old")
+
+			return
+		}
+
 		set := a.localCandidates[cand.NetworkType()]
 		for _, candidate := range set {
 			if candidate.Equal(cand) {
 				a.log.Debugf("Ignore duplicate candidate: %s", cand)
-				if err := cand.close(); err != nil {
-					a.log.Warnf("Failed to close duplicate candidate: %v", err)
-				}
-				if err := candidateConn.Close(); err != nil {
-					a.log.Warnf("Failed to close duplicate candidate connection: %v", err)
-				}
+				cleanupCandidate("duplicate")
 
 				return
 			}
 		}
 
-		a.setCandidateExtensions(cand)
+		a.setCandidateExtensions(cand, gen)
 		cand.start(a, candidateConn, a.startedCh)
 
 		set = append(set, cand)
@@ -1093,7 +1107,7 @@ func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn
 	})
 }
 
-func (a *Agent) setCandidateExtensions(cand Candidate) {
+func (a *Agent) setCandidateExtensions(cand Candidate, candidateGeneration uint64) {
 	err := cand.AddExtension(CandidateExtension{
 		Key:   "ufrag",
 		Value: a.localUfrag,
@@ -1101,6 +1115,14 @@ func (a *Agent) setCandidateExtensions(cand Candidate) {
 	if err != nil {
 		a.log.Errorf("Failed to add ufrag extension to candidate: %v", err)
 	}
+
+	err = cand.AddExtension(CandidateExtension{
+		Key:   "generation",
+		Value: strconv.FormatUint(candidateGeneration, 10),
+	})
+	if err != nil {
+		a.log.Errorf("Failed to add generation extension to candidate: %v", err)
+	}
 }
 
 // GetRemoteCandidates returns the remote candidates.
@@ -1637,6 +1659,10 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
 		if a.gatheringState == GatheringStateGathering {
 			a.gatherCandidateCancel()
 		}
+		if a.gatheringState != GatheringStateNew {
+			a.setGatheringStateLocked(GatheringStateComplete, a.gatherGeneration)
+		}
+		a.bumpGatheringGenerationLocked()
 
 		// Clear all agent needed to take back to fresh state
 		a.removeUfragFromMux()
@@ -1644,7 +1670,6 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
 		a.localPwd = pwd
 		a.remoteUfrag = ""
 		a.remotePwd = ""
-		a.gatheringState = GatheringStateNew
 		a.checklist = make([]*CandidatePair, 0)
 		a.pairsByID = make(map[uint64]*CandidatePair)
 		a.pendingBindingRequests = make([]bindingRequest, 0)
@@ -1664,14 +1689,10 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
 	return err
 }
 
-func (a *Agent) setGatheringState(newState GatheringState) error {
+func (a *Agent) setGatheringState(newState GatheringState, generation uint64) error {
 	done := make(chan struct{})
 	if err := a.loop.Run(a.loop, func(context.Context) {
-		if a.gatheringState != newState && newState == GatheringStateComplete {
-			a.candidateNotifier.EnqueueCandidate(nil)
-		}
-
-		a.gatheringState = newState
+		a.setGatheringStateLocked(newState, generation)
 		close(done)
 	}); err != nil {
 		return err
@@ -1682,6 +1703,30 @@ func (a *Agent) setGatheringState(newState GatheringState) error {
 	return nil
 }
 
+func (a *Agent) setGatheringStateLocked(newState GatheringState, generation uint64) {
+	if generation != a.gatherGeneration {
+		return
+	}
+
+	prevState := a.gatheringState
+	if prevState == newState {
+		return
+	}
+
+	if newState == GatheringStateComplete && !a.gatherEndSent {
+		a.candidateNotifier.EnqueueCandidate(nil)
+		a.gatherEndSent = true
+	}
+
+	a.gatheringState = newState
+}
+
+func (a *Agent) bumpGatheringGenerationLocked() {
+	a.gatherGeneration++
+	a.gatherEndSent = false
+	a.gatheringState = GatheringStateNew
+}
+
 func (a *Agent) needsToCheckPriorityOnNominated() bool {
 	return !a.lite || a.enableUseCandidateCheckPriority
 }
diff --git a/agent_test.go b/agent_test.go
index 6c569334..f67bd807 100644
--- a/agent_test.go
+++ b/agent_test.go
@@ -1385,9 +1385,7 @@ func TestAgentRestart(t *testing.T) {
 
 	t.Run("Restart Both Sides", func(t *testing.T) {
 		// Get all addresses of candidates concatenated
-		generateCandidateAddressStrings := func(candidates []Candidate, err error) (out string) {
-			require.NoError(t, err)
-
+		generateCandidateAddressStrings := func(candidates []Candidate) (out string) {
 			for _, c := range candidates {
 				out += c.Address() + ":"
 				out += strconv.Itoa(c.Port())
@@ -1396,14 +1394,31 @@ func TestAgentRestart(t *testing.T) {
 			return
 		}
 
+		candidateHasGeneration := func(generation uint64, candidate Candidate) {
+			genString := strconv.FormatUint(generation, 10)
+			ext, ok := candidate.GetExtension("generation")
+
+			require.True(t, ok)
+			require.Equal(t, genString, ext.Value)
+		}
+
 		// Store the original candidates, confirm that after we reconnect we have new pairs
 		connA, connB := pipe(t, &AgentConfig{
 			DisconnectedTimeout: &oneSecond,
 			FailedTimeout:       &oneSecond,
 		})
 		defer closePipe(t, connA, connB)
-		connAFirstCandidates := generateCandidateAddressStrings(connA.agent.GetLocalCandidates())
-		connBFirstCandidates := generateCandidateAddressStrings(connB.agent.GetLocalCandidates())
+
+		aFirstGeneration := connA.agent.gatherGeneration
+		bFirstGeneration := connB.agent.gatherGeneration
+
+		connAFirstCandidates, err := connA.agent.GetLocalCandidates()
+		require.NoError(t, err)
+		connBFirstCandidates, err := connB.agent.GetLocalCandidates()
+		require.NoError(t, err)
+
+		candidateHasGeneration(aFirstGeneration, connAFirstCandidates[0])
+		candidateHasGeneration(bFirstGeneration, connBFirstCandidates[0])
 
 		aNotifier, aConnected := onConnected()
 		require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier))
@@ -1415,6 +1430,10 @@ func TestAgentRestart(t *testing.T) {
 		require.NoError(t, connA.agent.Restart("", ""))
 		require.NoError(t, connB.agent.Restart("", ""))
 
+		// Generation should change after Restart call
+		require.NotEqual(t, aFirstGeneration, connA.agent.gatherGeneration)
+		require.NotEqual(t, bFirstGeneration, connB.agent.gatherGeneration)
+
 		// Exchange Candidates and Credentials
 		ufrag, pwd, err := connB.agent.GetLocalUserCredentials()
 		require.NoError(t, err)
@@ -1430,9 +1449,21 @@ func TestAgentRestart(t *testing.T) {
 		<-aConnected
 		<-bConnected
 
+		connASecondCandidates, err := connA.agent.GetLocalCandidates()
+		require.NoError(t, err)
+		connBSecondCandidates, err := connB.agent.GetLocalCandidates()
+		require.NoError(t, err)
+
+		candidateHasGeneration(connA.agent.gatherGeneration, connASecondCandidates[0])
+		candidateHasGeneration(connB.agent.gatherGeneration, connBSecondCandidates[0])
+
 		// Assert that we have new candidates each time
-		require.NotEqual(t, connAFirstCandidates, generateCandidateAddressStrings(connA.agent.GetLocalCandidates()))
-		require.NotEqual(t, connBFirstCandidates, generateCandidateAddressStrings(connB.agent.GetLocalCandidates()))
+		aFirstCandidatesString := generateCandidateAddressStrings(connAFirstCandidates)
+		aSecondCandidatesString := generateCandidateAddressStrings(connASecondCandidates)
+		bFirstCandidatesString := generateCandidateAddressStrings(connBFirstCandidates)
+		bSecondCandidatesString := generateCandidateAddressStrings(connBSecondCandidates)
+		require.NotEqual(t, aFirstCandidatesString, aSecondCandidatesString)
+		require.NotEqual(t, bFirstCandidatesString, bSecondCandidatesString)
 	})
 }
 
@@ -1511,7 +1542,7 @@ func TestGetLocalCandidates(t *testing.T) {
 
 		expectedCandidates = append(expectedCandidates, cand)
 
-		err = agent.addCandidate(context.Background(), cand, dummyConn)
+		err = agent.addCandidate(context.Background(), cand, dummyConn, agent.gatherGeneration)
 		require.NoError(t, err)
 	}
 
@@ -2151,7 +2182,7 @@ func TestSetCandidatesUfrag(t *testing.T) {
 		cand, errCand := NewCandidateHost(&cfg)
 		require.NoError(t, errCand)
 
-		err = agent.addCandidate(context.Background(), cand, dummyConn)
+		err = agent.addCandidate(context.Background(), cand, dummyConn, agent.gatherGeneration)
 		require.NoError(t, err)
 	}
 
@@ -2166,6 +2197,46 @@ func TestSetCandidatesUfrag(t *testing.T) {
 	}
 }
 
+func TestAddingCandidatesFromOtherGenerations(t *testing.T) {
+	var config AgentConfig
+
+	agent, err := NewAgent(&config)
+	require.NoError(t, err)
+	defer func() {
+		require.NoError(t, agent.Close())
+	}()
+
+	agent.gatherGeneration = 3
+
+	dummyConn := &net.UDPConn{}
+
+	for i := 0; i < 5; i++ {
+		cfg := CandidateHostConfig{
+			Network:   "udp",
+			Address:   "192.168.0.2",
+			Port:      1000 + i,
+			Component: 1,
+		}
+
+		cand, errCand := NewCandidateHost(&cfg)
+		require.NoError(t, errCand)
+
+		err = agent.addCandidate(context.Background(), cand, dummyConn, uint64(i)) //nolint:gosec
+		require.NoError(t, err)
+	}
+
+	actualCandidates, err := agent.GetLocalCandidates()
+	require.NoError(t, err)
+	require.Equal(t, 1, len(actualCandidates), "Only the candidate with a matching generation should be added")
+
+	ext, ok := actualCandidates[0].GetExtension("generation")
+	require.True(t, ok)
+
+	generation, err := strconv.ParseUint(ext.Value, 10, 64)
+	require.NoError(t, err)
+	require.Equal(t, agent.gatherGeneration, generation)
+}
+
 func TestAlwaysSentKeepAlive(t *testing.T) { //nolint:cyclop
 	defer test.CheckRoutines(t)()
 
diff --git a/gather.go b/gather.go
index 1fc25e22..0dfeed8b 100644
--- a/gather.go
+++ b/gather.go
@@ -50,6 +50,7 @@ func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ..
 // GatherCandidates initiates the trickle based gathering process.
 func (a *Agent) GatherCandidates() error {
 	var gatherErr error
+	var generation uint64
 
 	if runErr := a.loop.Run(a.loop, func(ctx context.Context) {
 		if a.gatheringState != GatheringStateNew {
@@ -67,8 +68,10 @@ func (a *Agent) GatherCandidates() error {
 		a.gatherCandidateCancel = cancel
 		done := make(chan struct{})
 		a.gatherCandidateDone = done
+		generation = a.gatherGeneration
+		a.setGatheringStateLocked(GatheringStateGathering, generation)
 
-		go a.gatherCandidates(ctx, done)
+		go a.gatherCandidates(ctx, done, generation)
 	}); runErr != nil {
 		return runErr
 	}
@@ -76,19 +79,17 @@ func (a *Agent) GatherCandidates() error {
 	return gatherErr
 }
 
-func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) { //nolint:cyclop
+func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}, generation uint64) { //nolint:cyclop
 	defer close(done)
-	if err := a.setGatheringState(GatheringStateGathering); err != nil { //nolint:contextcheck
-		a.log.Warnf("Failed to set gatheringState to GatheringStateGathering: %v", err)
-
+	if ctx.Err() != nil {
 		return
 	}
 
-	a.gatherCandidatesInternal(ctx)
+	a.gatherCandidatesInternal(ctx, generation)
 
 	switch a.continualGatheringPolicy {
 	case GatherOnce:
-		if err := a.setGatheringState(GatheringStateComplete); err != nil { //nolint:contextcheck
+		if err := a.setGatheringState(GatheringStateComplete, generation); err != nil { //nolint:contextcheck
 			a.log.Warnf("Failed to set gatheringState to GatheringStateComplete: %v", err)
 		}
 	case GatherContinually:
@@ -193,22 +194,22 @@ func (a *Agent) applyHostRewriteForUDPMux(candidateIPs []net.IP, udpAddr *net.UD
 }
 
 // gatherCandidatesInternal performs the actual candidate gathering for all configured types.
-func (a *Agent) gatherCandidatesInternal(ctx context.Context) {
+func (a *Agent) gatherCandidatesInternal(ctx context.Context, generation uint64) {
 	var wg sync.WaitGroup
 	for _, t := range a.candidateTypes {
 		switch t {
 		case CandidateTypeHost:
 			wg.Add(1)
 			go func() {
-				a.gatherCandidatesLocal(ctx, a.networkTypes)
+				a.gatherCandidatesLocal(ctx, a.networkTypes, generation)
 				wg.Done()
 			}()
 		case CandidateTypeServerReflexive:
-			a.gatherServerReflexiveCandidates(ctx, &wg)
+			a.gatherServerReflexiveCandidates(ctx, &wg, generation)
 		case CandidateTypeRelay:
 			wg.Add(1)
 			go func() {
-				a.gatherCandidatesRelay(ctx, a.urls)
+				a.gatherCandidatesRelay(ctx, a.urls, generation)
 				wg.Done()
 			}()
 		case CandidateTypePeerReflexive, CandidateTypeUnspecified:
@@ -219,15 +220,15 @@ func (a *Agent) gatherCandidatesInternal(ctx context.Context) {
 	wg.Wait()
 }
 
-func (a *Agent) gatherServerReflexiveCandidates(ctx context.Context, wg *sync.WaitGroup) {
+func (a *Agent) gatherServerReflexiveCandidates(ctx context.Context, wg *sync.WaitGroup, generation uint64) {
 	replaceSrflx := a.addressRewriteMapper != nil && a.addressRewriteMapper.shouldReplace(CandidateTypeServerReflexive)
 	if !replaceSrflx {
 		wg.Add(1)
 		go func() {
 			if a.udpMuxSrflx != nil {
-				a.gatherCandidatesSrflxUDPMux(ctx, a.urls, a.networkTypes)
+				a.gatherCandidatesSrflxUDPMux(ctx, a.urls, a.networkTypes, generation)
 			} else {
-				a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes)
+				a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes, generation)
 			}
 			wg.Done()
 		}()
@@ -235,14 +236,14 @@ func (a *Agent) gatherServerReflexiveCandidates(ctx context.Context, wg *sync.Wa
 	if a.addressRewriteMapper != nil && a.addressRewriteMapper.hasCandidateType(CandidateTypeServerReflexive) {
 		wg.Add(1)
 		go func() {
-			a.gatherCandidatesSrflxMapped(ctx, a.networkTypes)
+			a.gatherCandidatesSrflxMapped(ctx, a.networkTypes, generation)
 			wg.Done()
 		}()
 	}
 }
 
 //nolint:gocognit,gocyclo,cyclop,maintidx
-func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) {
+func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType, generation uint64) {
 	networks := map[string]struct{}{}
 	for _, networkType := range networkTypes {
 		if networkType.IsTCP() {
@@ -254,7 +255,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
 
 	// When UDPMux is enabled, skip other UDP candidates
 	if a.udpMux != nil {
-		if err := a.gatherCandidatesLocalUDPMux(ctx); err != nil {
+		if err := a.gatherCandidatesLocalUDPMux(ctx, generation); err != nil {
 			a.log.Warnf("Failed to create host candidate for UDPMux: %s", err)
 		}
 		delete(networks, udp)
@@ -415,7 +416,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
 				continue
 			}
 
-			if err := a.addCandidate(ctx, candidateHost, connAndPort.conn); err != nil {
+			if err := a.addCandidate(ctx, candidateHost, connAndPort.conn, generation); err != nil {
 				if closeErr := candidateHost.close(); closeErr != nil {
 					a.log.Warnf("Failed to close candidate: %v", closeErr)
 				}
@@ -449,7 +450,7 @@ func shouldFilterLocationTracked(candidateIP net.IP) bool {
 	return shouldFilterLocationTrackedIP(addr)
 }
 
-func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit,cyclop
+func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context, generation uint64) error { //nolint:gocognit,cyclop
 	if a.udpMux == nil {
 		return errUDPMuxDisabled
 	}
@@ -519,7 +520,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
 			continue
 		}
 
-		if err := a.addCandidate(ctx, c, conn); err != nil {
+		if err := a.addCandidate(ctx, c, conn, generation); err != nil {
 			if closeErr := c.close(); closeErr != nil {
 				a.log.Warnf("Failed to close candidate: %v", closeErr)
 			}
@@ -536,7 +537,8 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
 	return nil
 }
 
-func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []NetworkType) { //nolint:gocognit,cyclop
+//nolint:gocognit,cyclop
+func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []NetworkType, generation uint64) {
 	var wg sync.WaitGroup
 	defer wg.Wait()
 
@@ -634,7 +636,7 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
 				continue
 			}
 
-			if err := a.addCandidate(ctx, c, currentConn); err != nil {
+			if err := a.addCandidate(ctx, c, currentConn, generation); err != nil {
 				if closeErr := c.close(); closeErr != nil {
 					a.log.Warnf("Failed to close candidate: %v", closeErr)
 				}
@@ -652,7 +654,12 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
 }
 
 //nolint:gocognit,cyclop
-func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
+func (a *Agent) gatherCandidatesSrflxUDPMux(
+	ctx context.Context,
+	urls []*stun.URI,
+	networkTypes []NetworkType,
+	generation uint64,
+) {
 	var wg sync.WaitGroup
 	defer wg.Wait()
 
@@ -719,7 +726,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
 				return
 			}
 
-			if err := a.addCandidate(ctx, c, conn); err != nil {
+			if err := a.addCandidate(ctx, c, conn, generation); err != nil {
 				if closeErr := c.close(); closeErr != nil {
 					a.log.Warnf("Failed to close candidate: %v", closeErr)
 				}
@@ -732,7 +739,9 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
 }
 
 //nolint:cyclop,gocognit
-func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
+func (a *Agent) gatherCandidatesSrflx(
+	ctx context.Context, urls []*stun.URI, networkTypes []NetworkType, generation uint64,
+) {
 	var wg sync.WaitGroup
 	defer wg.Wait()
 
@@ -812,7 +821,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
 				return
 			}
 
-			if err := a.addCandidate(ctx, c, conn); err != nil {
+			if err := a.addCandidate(ctx, c, conn, generation); err != nil {
 				if closeErr := c.close(); closeErr != nil {
 					a.log.Warnf("Failed to close candidate: %v", closeErr)
 				}
@@ -824,7 +833,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
 }
 
 //nolint:maintidx,gocognit,gocyclo,cyclop
-func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) {
+func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI, generation uint64) {
 	var wg sync.WaitGroup
 	defer wg.Wait()
 
@@ -1022,7 +1031,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) {
 				return
 			}
 
-			a.addRelayCandidates(ctx, relayEndpoint{
+			a.addRelayCandidates(ctx, generation, relayEndpoint{
 				network: network,
 				address: rAddr.IP,
 				port:    rAddr.Port,
@@ -1141,7 +1150,9 @@ func findIfaceForIP(ifaces []ifaceAddr, ip net.IP) string {
 	return ""
 }
 
-func (a *Agent) createRelayCandidate(ctx context.Context, ep relayEndpoint, ip net.IP, onClose func() error) error {
+func (a *Agent) createRelayCandidate(
+	ctx context.Context, ep relayEndpoint, ip net.IP, generation uint64, onClose func() error,
+) error {
 	relayConfig := CandidateRelayConfig{
 		Network:   ep.network,
 		Component: ComponentRTP,
@@ -1159,7 +1170,7 @@ func (a *Agent) createRelayCandidate(ctx context.Context, ep relayEndpoint, ip n
 		return err
 	}
 
-	if err := a.addCandidate(ctx, candidate, ep.conn); err != nil {
+	if err := a.addCandidate(ctx, candidate, ep.conn, generation); err != nil {
 		if closeErr := candidate.close(); closeErr != nil {
 			a.log.Warnf("Failed to close candidate: %v", closeErr)
 		}
@@ -1171,7 +1182,7 @@ func (a *Agent) createRelayCandidate(ctx context.Context, ep relayEndpoint, ip n
 	return nil
 }
 
-func (a *Agent) addRelayCandidates(ctx context.Context, ep relayEndpoint) {
+func (a *Agent) addRelayCandidates(ctx context.Context, generation uint64, ep relayEndpoint) {
 	if ep.conn == nil || ep.address == nil {
 		return
 	}
@@ -1187,7 +1198,7 @@ func (a *Agent) addRelayCandidates(ctx context.Context, ep relayEndpoint) {
 			onClose = nil
 		}
 
-		if err := a.createRelayCandidate(ctx, ep, ip, onClose); err != nil {
+		if err := a.createRelayCandidate(ctx, ep, ip, generation, onClose); err != nil {
 			if idx == 0 {
 				if ep.closeConn != nil {
 					ep.closeConn()
@@ -1215,7 +1226,17 @@ func (a *Agent) startNetworkMonitoring(ctx context.Context) {
 			return
 		case <-ticker.C:
 			if a.detectNetworkChanges() {
-				a.gatherCandidatesInternal(ctx)
+				var generation uint64
+				if err := a.loop.Run(ctx, func(_ context.Context) {
+					a.bumpGatheringGenerationLocked()
+					generation = a.gatherGeneration
+				}); err != nil {
+					a.log.Warnf("Failed to bump gathering generation on network change: %v", err)
+
+					continue
+				}
+				a.gatherCandidatesInternal(ctx, generation)
 			}
 		}
 	}
diff --git a/gather_test.go b/gather_test.go
index 342cdc29..3641a71d 100644
--- a/gather_test.go
+++ b/gather_test.go
@@ -1088,7 +1088,7 @@ func TestGatherCandidatesRelayCallsAddRelayCandidates(t *testing.T) {
 		}
 	}))
 
-	agent.gatherCandidatesRelay(context.Background(), agent.urls)
+	agent.gatherCandidatesRelay(context.Background(), agent.urls, agent.gatherGeneration)
 
 	var cand Candidate
 	select {
@@ -1151,7 +1151,7 @@ func TestGatherCandidatesRelayUsesTurnNet(t *testing.T) {
 		}
 	}))
 
-	agent.gatherCandidatesRelay(context.Background(), agent.urls)
+	agent.gatherCandidatesRelay(context.Background(), agent.urls, agent.gatherGeneration)
 
 	select {
 	case cand := <-candCh:
@@ -1199,7 +1199,7 @@ func TestGatherCandidatesRelayDefaultClientError(t *testing.T) {
 		}
 	}))
 
-	agent.gatherCandidatesRelay(context.Background(), agent.urls)
+	agent.gatherCandidatesRelay(context.Background(), agent.urls, agent.gatherGeneration)
 
 	select {
 	case <-candidateCh:
@@ -1337,7 +1337,7 @@ func TestGatherCandidatesSrflxMappedPortRangeError(t *testing.T) {
 	agent.portMin = 9000
 	agent.portMax = 8000
 
-	agent.gatherCandidatesSrflxMapped(context.Background(), []NetworkType{NetworkTypeUDP4})
+	agent.gatherCandidatesSrflxMapped(context.Background(), []NetworkType{NetworkTypeUDP4}, agent.gatherGeneration)
 
 	localCandidates, err := agent.GetLocalCandidates()
 	require.NoError(t, err)
@@ -1352,7 +1352,7 @@ func TestGatherCandidatesLocalUDPMux(t *testing.T) {
 			require.NoError(t, agent.Close())
 		}()
 
-		err = agent.gatherCandidatesLocalUDPMux(context.Background())
+		err = agent.gatherCandidatesLocalUDPMux(context.Background(), agent.gatherGeneration)
 		require.ErrorIs(t, err, errUDPMuxDisabled)
 	})
 
@@ -1373,7 +1373,7 @@ func TestGatherCandidatesLocalUDPMux(t *testing.T) {
 
 		require.NoError(t, agent.OnCandidate(func(Candidate) {}))
 
-		err = agent.gatherCandidatesLocalUDPMux(context.Background())
+		err = agent.gatherCandidatesLocalUDPMux(context.Background(), agent.gatherGeneration)
 		require.NoError(t, err)
 
 		candidates, err := agent.GetLocalCandidates()
@@ -1414,7 +1414,9 @@ func TestGatherCandidatesSrflxUDPMux(t *testing.T) {
 
 		require.NoError(t, agent.OnCandidate(func(Candidate) {}))
 
-		agent.gatherCandidatesSrflxUDPMux(context.Background(), []*stun.URI{stunURI}, []NetworkType{NetworkTypeUDP4})
+		agent.gatherCandidatesSrflxUDPMux(
+			context.Background(), []*stun.URI{stunURI}, []NetworkType{NetworkTypeUDP4}, agent.gatherGeneration,
+		)
 
 		candidates, err := agent.GetLocalCandidates()
 		require.NoError(t, err)
@@ -2118,7 +2120,7 @@ func TestAddRelayCandidatesWithRewrite(t *testing.T) {
 		agent.loop.Close()
 	})
 
-	agent.addRelayCandidates(ctx, ep)
+	agent.addRelayCandidates(ctx, agent.gatherGeneration, ep)
 
 	cands := agent.localCandidates[NetworkTypeUDP4]
 	require.Len(t, cands, 2)
@@ -2144,7 +2146,7 @@ func TestAddRelayCandidatesSkipsNilConnOrAddress(t *testing.T) {
 
 	ctx := context.Background()
 
-	agent.addRelayCandidates(ctx, relayEndpoint{
+	agent.addRelayCandidates(ctx, agent.gatherGeneration, relayEndpoint{
 		network: NetworkTypeUDP4.String(),
 		address: net.IPv4(10, 0, 0, 1),
 		port:    3478,
@@ -2156,7 +2158,7 @@ func TestAddRelayCandidatesSkipsNilConnOrAddress(t *testing.T) {
 	require.NoError(t, err)
 	assert.Len(t, cands, 0)
 
-	agent.addRelayCandidates(ctx, relayEndpoint{
+	agent.addRelayCandidates(ctx, agent.gatherGeneration, relayEndpoint{
 		network: NetworkTypeUDP4.String(),
 		address: nil,
 		port:    3478,
@@ -2197,7 +2199,7 @@ func TestAddRelayCandidatesSkipsWhenResolveFails(t *testing.T) {
 		agent.loop.Close()
 	})
 
-	agent.addRelayCandidates(context.Background(), relayEndpoint{
+	agent.addRelayCandidates(context.Background(), agent.gatherGeneration, relayEndpoint{
 		network: NetworkTypeUDP4.String(),
 		address: net.IPv4(10, 0, 0, 2),
 		port:    3478,
@@ -2236,7 +2238,7 @@ func TestAddRelayCandidatesSkipsWhenResolveFails(t *testing.T) {
 		agent.loop.Close()
 	})
 
-	agent.addRelayCandidates(context.Background(), relayEndpoint{
+	agent.addRelayCandidates(context.Background(), agent.gatherGeneration, relayEndpoint{
 		network: NetworkTypeUDP4.String(),
 		address: net.IPv4(10, 0, 0, 3),
 		port:    3478,
@@ -2281,7 +2283,7 @@ func TestCreateRelayCandidateErrorPaths(t *testing.T) {
 		},
 	}
 
-	agent.addRelayCandidates(context.Background(), ep)
+	agent.addRelayCandidates(context.Background(), agent.gatherGeneration, ep)
 
 	cands, err := agent.GetLocalCandidates()
 	require.NoError(t, err)
@@ -2309,7 +2311,7 @@ func TestCreateRelayCandidateErrorPaths(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	cancel() // force addCandidate to fail
 
-	agent.addRelayCandidates(ctx, relayEndpoint{
+	agent.addRelayCandidates(ctx, agent.gatherGeneration, relayEndpoint{
 		network: NetworkTypeUDP4.String(),
 		address: net.IPv4(10, 0, 0, 5),
 		port:    3478,
@@ -2348,7 +2350,7 @@ func TestGatherCandidatesLocalTCPMuxSkipsUnboundInterfaces(t *testing.T) {
 	})
 	require.NoError(t, agent.OnCandidate(func(Candidate) {}))
 
-	agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeTCP4})
+	agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeTCP4}, agent.gatherGeneration)
 
 	cands, err := agent.GetLocalCandidates()
 	require.NoError(t, err)
@@ -2373,7 +2375,7 @@ func TestGatherCandidatesLocalHostErrorPaths(t *testing.T) {
 		})
 		require.NoError(t, agent.OnCandidate(func(Candidate) {}))
 
-		assert.NoError(t, agent.gatherCandidatesLocalUDPMux(context.Background()))
+		assert.NoError(t, agent.gatherCandidatesLocalUDPMux(context.Background(), agent.gatherGeneration))
 		assert.True(t, mux.conn.closed)
 
 		cands, err := agent.GetLocalCandidates()
@@ -2399,7 +2401,7 @@ func TestGatherCandidatesLocalHostErrorPaths(t *testing.T) {
 		agent.includeLoopback = true
 		agent.mDNSName = "invalid-mdns" // no .local suffix -> NewCandidateHost parse fails
 
-		agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4})
+		agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4}, agent.gatherGeneration)
 
 		cands, err := agent.GetLocalCandidates()
 		require.NoError(t, err)
@@ -2425,7 +2427,7 @@ func TestGatherCandidatesLocalHostErrorPaths(t *testing.T) {
 
 		agent.loop.Close()
 
-		agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4})
+		agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4}, agent.gatherGeneration)
 
 		agent.loop.Run(agent.loop, func(context.Context) { //nolint:errcheck,gosec
 			assert.Empty(t, agent.localCandidates[NetworkTypeUDP4])
@@ -2459,7 +2461,7 @@ func TestGatherCandidatesLocalHostErrorPaths(t *testing.T) {
 			agent.loop.Close()
 		})
 
-		agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4})
+		agent.gatherCandidatesLocal(context.Background(), []NetworkType{NetworkTypeUDP4}, agent.gatherGeneration)
 
 		cands, err := agent.GetLocalCandidates()
 		require.NoError(t, err)
@@ -2488,7 +2490,7 @@ func TestGatherCandidatesLocalHostErrorPaths(t *testing.T) {
 		})
 		require.NoError(t, agent.OnCandidate(func(Candidate) {}))
 
-		require.NoError(t, agent.gatherCandidatesLocalUDPMux(context.Background()))
+		require.NoError(t, agent.gatherCandidatesLocalUDPMux(context.Background(), agent.gatherGeneration))
 
 		cands, err := agent.GetLocalCandidates()
 		require.NoError(t, err)
@@ -2986,7 +2988,7 @@ func TestGatherAddressRewriteRelayModes(t *testing.T) {
 			require.NoError(t, agent.Close())
 		})
 
-		agent.addRelayCandidates(context.Background(), relayEndpoint{
+		agent.addRelayCandidates(context.Background(), agent.gatherGeneration, relayEndpoint{
 			network: NetworkTypeUDP4.String(),
 			address: net.ParseIP("192.0.2.10"),
 			port:    5000,
@@ -3020,7 +3022,7 @@ func TestGatherAddressRewriteRelayModes(t *testing.T) {
 			require.NoError(t, agent.Close())
 		})
 
-		agent.addRelayCandidates(context.Background(), relayEndpoint{
+		agent.addRelayCandidates(context.Background(), agent.gatherGeneration, relayEndpoint{
 			network: NetworkTypeUDP4.String(),
 			address: net.ParseIP("192.0.2.20"),
 			port:    6000,
@@ -3280,7 +3282,7 @@ func TestGatherCandidatesSrflxMappedMissingExternalIPs(t *testing.T) {
 		},
 	}
 
-	agent.gatherCandidatesSrflxMapped(context.Background(), []NetworkType{NetworkTypeUDP4})
+	agent.gatherCandidatesSrflxMapped(context.Background(), []NetworkType{NetworkTypeUDP4}, agent.gatherGeneration)
 
 	localCandidates, err := agent.GetLocalCandidates()
 	require.NoError(t, err)
diff --git a/gather_vnet_test.go b/gather_vnet_test.go
index 09287996..02f48c20 100644
--- a/gather_vnet_test.go
+++ b/gather_vnet_test.go
@@ -555,5 +555,5 @@ func TestVNetGather_TURNConnectionLeak(t *testing.T) {
 		require.NoError(t, aAgent.Close())
 	}()
 
-	aAgent.gatherCandidatesRelay(context.Background(), []*stun.URI{turnServerURL})
+	aAgent.gatherCandidatesRelay(context.Background(), []*stun.URI{turnServerURL}, aAgent.gatherGeneration)
 }
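
Illustrative usage, not part of this diff: a minimal sketch of how a consumer of the agent could read the new "generation" extension off trickled candidates and skip ones that raced across an ICE restart. The github.com/pion/ice/v4 import path and the watchCandidates helper are assumptions; only APIs already exercised above (OnCandidate, GetExtension, strconv.ParseUint) are relied on.

package demo

import (
	"fmt"
	"strconv"

	"github.com/pion/ice/v4"
)

// watchCandidates is a hypothetical consumer-side helper: it remembers the
// highest generation seen so far and ignores candidates carrying an older
// "generation" extension (for example, ones emitted just before a restart).
// It assumes candidate callbacks are delivered sequentially by the agent.
func watchCandidates(agent *ice.Agent) error {
	var latest uint64

	return agent.OnCandidate(func(c ice.Candidate) {
		if c == nil { // nil marks the end of a gathering run
			fmt.Println("gathering complete")

			return
		}

		var gen uint64
		if ext, ok := c.GetExtension("generation"); ok {
			if parsed, err := strconv.ParseUint(ext.Value, 10, 64); err == nil {
				gen = parsed
			}
		}

		if gen < latest {
			return // stale candidate from an earlier gather generation
		}
		latest = gen

		fmt.Printf("candidate (generation %d): %s\n", gen, c)
	})
}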