Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 63 additions & 18 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"math"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -58,8 +59,10 @@ type Agent struct {
tieBreaker uint64
lite bool

connectionState ConnectionState
gatheringState GatheringState
connectionState ConnectionState
gatheringState GatheringState
gatherGeneration uint64
gatherEndSent bool

mDNSMode MulticastDNSMode
mDNSName string
Expand Down Expand Up @@ -1052,28 +1055,39 @@ func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
a.requestConnectivityCheck()
}

func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn) error {
func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn, gen uint64) error {
if err := ctx.Err(); err != nil {
return err
}

cleanupCandidate := func(reason string) {
if err := cand.close(); err != nil {
a.log.Warnf("Failed to close %s candidate: %v", reason, err)
}
if err := candidateConn.Close(); err != nil {
a.log.Warnf("Failed to close %s candidate connection: %v", reason, err)
}
}

return a.loop.Run(ctx, func(context.Context) {
if a.gatherGeneration != gen {
a.log.Debugf("Ignoring candidate from different gather generation (a: %d c: %d)", a.gatherGeneration, gen)
cleanupCandidate("old")

return
}

set := a.localCandidates[cand.NetworkType()]
for _, candidate := range set {
if candidate.Equal(cand) {
a.log.Debugf("Ignore duplicate candidate: %s", cand)
if err := cand.close(); err != nil {
a.log.Warnf("Failed to close duplicate candidate: %v", err)
}
if err := candidateConn.Close(); err != nil {
a.log.Warnf("Failed to close duplicate candidate connection: %v", err)
}
cleanupCandidate("duplicate")

return
}
}

a.setCandidateExtensions(cand)
a.setCandidateExtensions(cand, gen)
cand.start(a, candidateConn, a.startedCh)

set = append(set, cand)
Expand All @@ -1093,14 +1107,22 @@ func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn
})
}

func (a *Agent) setCandidateExtensions(cand Candidate) {
func (a *Agent) setCandidateExtensions(cand Candidate, candidateGeneration uint64) {
err := cand.AddExtension(CandidateExtension{
Key: "ufrag",
Value: a.localUfrag,
})
if err != nil {
a.log.Errorf("Failed to add ufrag extension to candidate: %v", err)
}

err = cand.AddExtension(CandidateExtension{
Key: "generation",
Value: strconv.FormatUint(candidateGeneration, 10),
})
if err != nil {
a.log.Errorf("Failed to add generation extension to candidate: %v", err)
}
}

// GetRemoteCandidates returns the remote candidates.
Expand Down Expand Up @@ -1637,14 +1659,17 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
if a.gatheringState == GatheringStateGathering {
a.gatherCandidateCancel()
}
if a.gatheringState != GatheringStateNew {
a.setGatheringStateLocked(GatheringStateComplete, a.gatherGeneration)
}
a.bumpGatheringGenerationLocked()

// Clear all agent needed to take back to fresh state
a.removeUfragFromMux()
a.localUfrag = ufrag
a.localPwd = pwd
a.remoteUfrag = ""
a.remotePwd = ""
a.gatheringState = GatheringStateNew
a.checklist = make([]*CandidatePair, 0)
a.pairsByID = make(map[uint64]*CandidatePair)
a.pendingBindingRequests = make([]bindingRequest, 0)
Expand All @@ -1664,14 +1689,10 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
return err
}

func (a *Agent) setGatheringState(newState GatheringState) error {
func (a *Agent) setGatheringState(newState GatheringState, generation uint64) error {
done := make(chan struct{})
if err := a.loop.Run(a.loop, func(context.Context) {
if a.gatheringState != newState && newState == GatheringStateComplete {
a.candidateNotifier.EnqueueCandidate(nil)
}

a.gatheringState = newState
a.setGatheringStateLocked(newState, generation)
close(done)
}); err != nil {
return err
Expand All @@ -1682,6 +1703,30 @@ func (a *Agent) setGatheringState(newState GatheringState) error {
return nil
}

func (a *Agent) setGatheringStateLocked(newState GatheringState, generation uint64) {
if generation != a.gatherGeneration {
return
}

prevState := a.gatheringState
if prevState == newState {
return
}

if newState == GatheringStateComplete && !a.gatherEndSent {
a.candidateNotifier.EnqueueCandidate(nil)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JoTurk I wonder if we should also be concerned about calling this from within the "Lock" / AgentLoop here. Specifically the fact that EnqueueCandidate starts by acquiring the handlerNotifier Lock. I'd be concerned that we could end up delaying the Agent Loop if not fully deadlocking it.

a.gatherEndSent = true
}

a.gatheringState = newState
}

func (a *Agent) bumpGatheringGenerationLocked() {
a.gatherGeneration++
a.gatherEndSent = false
a.gatheringState = GatheringStateNew
}

func (a *Agent) needsToCheckPriorityOnNominated() bool {
return !a.lite || a.enableUseCandidateCheckPriority
}
Expand Down
89 changes: 80 additions & 9 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1385,9 +1385,7 @@ func TestAgentRestart(t *testing.T) {

t.Run("Restart Both Sides", func(t *testing.T) {
// Get all addresses of candidates concatenated
generateCandidateAddressStrings := func(candidates []Candidate, err error) (out string) {
require.NoError(t, err)

generateCandidateAddressStrings := func(candidates []Candidate) (out string) {
for _, c := range candidates {
out += c.Address() + ":"
out += strconv.Itoa(c.Port())
Expand All @@ -1396,14 +1394,31 @@ func TestAgentRestart(t *testing.T) {
return
}

candidateHasGeneration := func(generation uint64, candidate Candidate) {
genString := strconv.FormatUint(generation, 10)
ext, ok := candidate.GetExtension("generation")

require.True(t, ok)
require.Equal(t, genString, ext.Value)
}

// Store the original candidates, confirm that after we reconnect we have new pairs
connA, connB := pipe(t, &AgentConfig{
DisconnectedTimeout: &oneSecond,
FailedTimeout: &oneSecond,
})
defer closePipe(t, connA, connB)
connAFirstCandidates := generateCandidateAddressStrings(connA.agent.GetLocalCandidates())
connBFirstCandidates := generateCandidateAddressStrings(connB.agent.GetLocalCandidates())

aFirstGeneration := connA.agent.gatherGeneration
bFirstGeneration := connB.agent.gatherGeneration

connAFirstCandidates, err := connA.agent.GetLocalCandidates()
require.NoError(t, err)
connBFirstCandidates, err := connB.agent.GetLocalCandidates()
require.NoError(t, err)

candidateHasGeneration(aFirstGeneration, connAFirstCandidates[0])
candidateHasGeneration(bFirstGeneration, connBFirstCandidates[0])

aNotifier, aConnected := onConnected()
require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier))
Expand All @@ -1415,6 +1430,10 @@ func TestAgentRestart(t *testing.T) {
require.NoError(t, connA.agent.Restart("", ""))
require.NoError(t, connB.agent.Restart("", ""))

// Generation should change after Restart call
require.NotEqual(t, aFirstGeneration, connA.agent.gatherGeneration)
require.NotEqual(t, bFirstGeneration, connB.agent.gatherGeneration)

// Exchange Candidates and Credentials
ufrag, pwd, err := connB.agent.GetLocalUserCredentials()
require.NoError(t, err)
Expand All @@ -1430,9 +1449,21 @@ func TestAgentRestart(t *testing.T) {
<-aConnected
<-bConnected

connASecondCandidates, err := connA.agent.GetLocalCandidates()
require.NoError(t, err)
connBSecondCandidates, err := connB.agent.GetLocalCandidates()
require.NoError(t, err)

candidateHasGeneration(connA.agent.gatherGeneration, connASecondCandidates[0])
candidateHasGeneration(connB.agent.gatherGeneration, connASecondCandidates[0])

// Assert that we have new candidates each time
require.NotEqual(t, connAFirstCandidates, generateCandidateAddressStrings(connA.agent.GetLocalCandidates()))
require.NotEqual(t, connBFirstCandidates, generateCandidateAddressStrings(connB.agent.GetLocalCandidates()))
aFirstCandidatesString := generateCandidateAddressStrings(connAFirstCandidates)
aSecondCandidatesString := generateCandidateAddressStrings(connASecondCandidates)
bFirstCandidatesString := generateCandidateAddressStrings(connBFirstCandidates)
bSecondCandidatesString := generateCandidateAddressStrings(connBSecondCandidates)
require.NotEqual(t, aFirstCandidatesString, aSecondCandidatesString)
require.NotEqual(t, bFirstCandidatesString, bSecondCandidatesString)
})
}

Expand Down Expand Up @@ -1511,7 +1542,7 @@ func TestGetLocalCandidates(t *testing.T) {

expectedCandidates = append(expectedCandidates, cand)

err = agent.addCandidate(context.Background(), cand, dummyConn)
err = agent.addCandidate(context.Background(), cand, dummyConn, agent.gatherGeneration)
require.NoError(t, err)
}

Expand Down Expand Up @@ -2151,7 +2182,7 @@ func TestSetCandidatesUfrag(t *testing.T) {
cand, errCand := NewCandidateHost(&cfg)
require.NoError(t, errCand)

err = agent.addCandidate(context.Background(), cand, dummyConn)
err = agent.addCandidate(context.Background(), cand, dummyConn, agent.gatherGeneration)
require.NoError(t, err)
}

Expand All @@ -2166,6 +2197,46 @@ func TestSetCandidatesUfrag(t *testing.T) {
}
}

func TestAddingCandidatesFromOtherGenerations(t *testing.T) {
var config AgentConfig

agent, err := NewAgent(&config)
require.NoError(t, err)
defer func() {
require.NoError(t, agent.Close())
}()

agent.gatherGeneration = 3

dummyConn := &net.UDPConn{}

for i := 0; i < 5; i++ {
cfg := CandidateHostConfig{
Network: "udp",
Address: "192.168.0.2",
Port: 1000 + i,
Component: 1,
}

cand, errCand := NewCandidateHost(&cfg)
require.NoError(t, errCand)

err = agent.addCandidate(context.Background(), cand, dummyConn, uint64(i)) // nolint:gosec
require.NoError(t, err)
}

actualCandidates, err := agent.GetLocalCandidates()
require.NoError(t, err)
require.Equal(t, 1, len(actualCandidates), "Only the candidate with a matching generation should be added")

ext, ok := actualCandidates[0].GetExtension("generation")
require.True(t, ok)

generation, err := strconv.ParseUint(ext.Value, 10, 64)
require.NoError(t, err)
require.Equal(t, agent.gatherGeneration, generation)
}

func TestAlwaysSentKeepAlive(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()

Expand Down
Loading
Loading