Skip to content

Commit d45e481

Browse files
Merge pull request #272 from libp2p/clean-up-dial-sync
simplify the DialSync code
2 parents bf044ff + 0e0111c commit d45e481

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

p2p/net/swarm/dial_sync.go

+20-25
Original file line numberDiff line numberDiff line change
@@ -22,35 +22,26 @@ func newDialSync(worker dialWorkerFunc) *DialSync {
2222
// DialSync is a dial synchronization helper that ensures that at most one dial
2323
// to any given peer is active at any given time.
2424
type DialSync struct {
25+
mutex sync.Mutex
2526
dials map[peer.ID]*activeDial
26-
dialsLk sync.Mutex
2727
dialWorker dialWorkerFunc
2828
}
2929

3030
type activeDial struct {
31-
id peer.ID
3231
refCnt int
3332

3433
ctx context.Context
3534
cancel func()
3635

3736
reqch chan dialRequest
38-
39-
ds *DialSync
4037
}
4138

42-
func (ad *activeDial) decref() {
43-
ad.ds.dialsLk.Lock()
44-
ad.refCnt--
45-
if ad.refCnt == 0 {
46-
ad.cancel()
47-
close(ad.reqch)
48-
delete(ad.ds.dials, ad.id)
49-
}
50-
ad.ds.dialsLk.Unlock()
39+
func (ad *activeDial) close() {
40+
ad.cancel()
41+
close(ad.reqch)
5142
}
5243

53-
func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
44+
func (ad *activeDial) dial(ctx context.Context) (*Conn, error) {
5445
dialCtx := ad.ctx
5546

5647
if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
@@ -76,30 +67,26 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
7667
}
7768

7869
func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
79-
ds.dialsLk.Lock()
80-
defer ds.dialsLk.Unlock()
70+
ds.mutex.Lock()
71+
defer ds.mutex.Unlock()
8172

8273
actd, ok := ds.dials[p]
8374
if !ok {
8475
// This code intentionally uses the background context. Otherwise, if the first call
8576
// to Dial is canceled, subsequent dial calls will also be canceled.
8677
// XXX: this also breaks direct connection logic. We will need to pipe the
8778
// information through some other way.
88-
adctx, cancel := context.WithCancel(context.Background())
79+
ctx, cancel := context.WithCancel(context.Background())
8980
actd = &activeDial{
90-
id: p,
91-
ctx: adctx,
81+
ctx: ctx,
9282
cancel: cancel,
9383
reqch: make(chan dialRequest),
94-
ds: ds,
9584
}
9685
go ds.dialWorker(p, actd.reqch)
9786
ds.dials[p] = actd
9887
}
99-
100-
// increase ref count before dropping dialsLk
88+
// increase ref count before dropping mutex
10189
actd.refCnt++
102-
10390
return actd, nil
10491
}
10592

@@ -111,6 +98,14 @@ func (ds *DialSync) Dial(ctx context.Context, p peer.ID) (*Conn, error) {
11198
return nil, err
11299
}
113100

114-
defer ad.decref()
115-
return ad.dial(ctx, p)
101+
defer func() {
102+
ds.mutex.Lock()
103+
defer ds.mutex.Unlock()
104+
ad.refCnt--
105+
if ad.refCnt == 0 {
106+
ad.close()
107+
delete(ds.dials, p)
108+
}
109+
}()
110+
return ad.dial(ctx)
116111
}

0 commit comments

Comments
 (0)