Skip to content

Commit e09ef04

Browse files
committed
store map of peers supporting DialProtocol
1 parent bd38997 commit e09ef04

File tree

4 files changed

+145
-90
lines changed

4 files changed

+145
-90
lines changed

p2p/host/blank/blank.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost {
6363
}
6464

6565
bh := &BlankHost{
66-
n: n,
67-
cmgr: cfg.cmgr,
68-
mux: mstream.NewMultistreamMuxer[protocol.ID](),
66+
n: n,
67+
cmgr: cfg.cmgr,
68+
mux: mstream.NewMultistreamMuxer[protocol.ID](),
69+
eventbus: cfg.eventBus,
6970
}
7071
if bh.eventbus == nil {
7172
bh.eventbus = eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer()))

p2p/protocol/autonatv2/autonat.go

+50-37
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pbv2"
1616
ma "github.com/multiformats/go-multiaddr"
1717
manet "github.com/multiformats/go-multiaddr/net"
18-
"golang.org/x/exp/rand"
1918
)
2019

2120
const (
@@ -45,6 +44,8 @@ type AutoNAT struct {
4544
wg sync.WaitGroup
4645
srv *Server
4746
cli *Client
47+
mx sync.Mutex
48+
peers map[peer.ID]struct{}
4849
allowAllAddrs bool // for testing
4950
}
5051

@@ -55,7 +56,11 @@ func New(h host.Host, dialer host.Host, opts ...AutoNATOption) (*AutoNAT, error)
5556
return nil, err
5657
}
5758
}
58-
sub, err := h.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged))
59+
sub, err := h.EventBus().Subscribe([]interface{}{
60+
new(event.EvtLocalReachabilityChanged),
61+
new(event.EvtPeerProtocolsUpdated),
62+
new(event.EvtPeerConnectednessChanged),
63+
})
5964
if err != nil {
6065
return nil, fmt.Errorf("failed to subscribe to event.EvtLocalReachabilityChanged: %w", err)
6166
}
@@ -69,6 +74,7 @@ func New(h host.Host, dialer host.Host, opts ...AutoNATOption) (*AutoNAT, error)
6974
srv: NewServer(h, dialer, s),
7075
cli: NewClient(h),
7176
allowAllAddrs: s.allowAllAddrs,
77+
peers: make(map[peer.ID]struct{}),
7278
}
7379
an.cli.Register()
7480

@@ -84,28 +90,31 @@ func (an *AutoNAT) background() {
8490
an.srv.Disable()
8591
an.wg.Done()
8692
return
87-
case evt := <-an.sub.Out():
88-
// Enable the server only if we're publicly reachable.
89-
//
90-
// Currently this event is sent by the AutoNAT v1 module. During the
91-
// transition period from AutoNAT v1 to v2, there won't be enough v2 servers
92-
// on the network and most clients will be unable to discover a peer which
93-
// supports AutoNAT v2. So, we use v1 to determine reachability for the
94-
// transition period.
95-
//
96-
// Once there are enough v2 servers on the network for nodes to determine
97-
// their reachability using AutoNAT v2, we'll use Address Pipeline
98-
// (https://github.com/libp2p/go-libp2p/issues/2229)(to be implemented in a
99-
// future release) to determine reachability using v2 client and send this
100-
// event if we are publicly reachable.
101-
revt, ok := evt.(event.EvtLocalReachabilityChanged)
102-
if !ok {
103-
log.Errorf("Unexpected event %s of type %T", evt, evt)
104-
}
105-
if revt.Reachability == network.ReachabilityPrivate {
106-
an.srv.Disable()
107-
} else {
108-
an.srv.Enable()
93+
case e := <-an.sub.Out():
94+
switch evt := e.(type) {
95+
case event.EvtLocalReachabilityChanged:
96+
// Enable the server only if we're publicly reachable.
97+
//
98+
// Currently this event is sent by the AutoNAT v1 module. During the
99+
// transition period from AutoNAT v1 to v2, there won't be enough v2 servers
100+
// on the network and most clients will be unable to discover a peer which
101+
// supports AutoNAT v2. So, we use v1 to determine reachability for the
102+
// transition period.
103+
//
104+
// Once there are enough v2 servers on the network for nodes to determine
105+
// their reachability using AutoNAT v2, we'll use Address Pipeline
106+
// (https://github.com/libp2p/go-libp2p/issues/2229)(to be implemented in a
107+
// future release) to determine reachability using v2 client and send this
108+
// event from Address Pipeline, if we are publicly reachable.
109+
if evt.Reachability == network.ReachabilityPrivate {
110+
an.srv.Disable()
111+
} else {
112+
an.srv.Enable()
113+
}
114+
case event.EvtPeerProtocolsUpdated:
115+
an.updatePeer(evt.Peer)
116+
case event.EvtPeerConnectednessChanged:
117+
an.updatePeer(evt.Peer)
109118
}
110119
}
111120
}
@@ -140,21 +149,25 @@ func (an *AutoNAT) CheckReachability(ctx context.Context, highPriorityAddrs []ma
140149
}
141150

142151
func (an *AutoNAT) validPeer() peer.ID {
143-
peers := an.host.Peerstore().Peers()
144-
idx := 0
145-
for i := 0; i < len(peers); i++ {
146-
if proto, err := an.host.Peerstore().SupportsProtocols(peers[i], DialProtocol); len(proto) == 0 || err != nil {
147-
continue
148-
}
149-
peers[idx] = peers[i]
150-
idx++
152+
an.mx.Lock()
153+
defer an.mx.Unlock()
154+
for p := range an.peers {
155+
return p
151156
}
152-
if idx == 0 {
153-
return ""
157+
return ""
158+
}
159+
160+
func (an *AutoNAT) updatePeer(p peer.ID) {
161+
an.mx.Lock()
162+
defer an.mx.Unlock()
163+
164+
_, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol)
165+
connState := an.host.Network().Connectedness(p)
166+
if err == nil && connState == network.Connected {
167+
an.peers[p] = struct{}{}
168+
} else {
169+
delete(an.peers, p)
154170
}
155-
peers = peers[:idx]
156-
rand.Shuffle(len(peers), func(i, j int) { peers[i], peers[j] = peers[j], peers[i] })
157-
return peers[0]
158171
}
159172

160173
type Result struct {

p2p/protocol/autonatv2/autonat_test.go

+87-39
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"reflect"
7+
"sync/atomic"
78
"testing"
89
"time"
910

@@ -12,6 +13,7 @@ import (
1213
"github.com/libp2p/go-libp2p/core/peer"
1314
"github.com/libp2p/go-libp2p/core/peerstore"
1415
bhost "github.com/libp2p/go-libp2p/p2p/host/blank"
16+
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
1517
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
1618
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pbv2"
1719

@@ -22,14 +24,17 @@ import (
2224

2325
func newAutoNAT(t *testing.T, dialer host.Host, opts ...AutoNATOption) *AutoNAT {
2426
t.Helper()
25-
h := bhost.NewBlankHost(swarmt.GenSwarm(t))
27+
b := eventbus.NewBus()
28+
h := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.EventBus(b)), bhost.WithEventBus(b))
2629
if dialer == nil {
2730
dialer = bhost.NewBlankHost(swarmt.GenSwarm(t))
2831
}
2932
an, err := New(h, dialer, opts...)
3033
if err != nil {
3134
t.Error(err)
3235
}
36+
an.srv.Enable()
37+
an.cli.Register()
3338
return an
3439
}
3540

@@ -47,31 +52,28 @@ func parseAddrs(t *testing.T, msg *pbv2.Message) []ma.Multiaddr {
4752
return addrs
4853
}
4954

50-
func TestValidPeer(t *testing.T) {
51-
an := newAutoNAT(t, nil)
52-
require.Equal(t, an.validPeer(), peer.ID(""))
53-
an.host.Peerstore().AddAddr("peer1", ma.StringCast("/ip4/127.0.0.1/tcp/1"), peerstore.PermanentAddrTTL)
54-
an.host.Peerstore().AddAddr("peer2", ma.StringCast("/ip4/127.0.0.1/tcp/2"), peerstore.PermanentAddrTTL)
55-
require.NoError(t, an.host.Peerstore().AddProtocols("peer1", DialProtocol))
56-
require.NoError(t, an.host.Peerstore().AddProtocols("peer2", DialProtocol))
57-
58-
var got1, got2 bool
59-
for i := 0; i < 100; i++ {
60-
p := an.validPeer()
61-
switch p {
62-
case peer.ID("peer1"):
63-
got1 = true
64-
case peer.ID("peer2"):
65-
got2 = true
66-
default:
67-
t.Fatalf("invalid peer: %s", p)
68-
}
69-
if got1 && got2 {
70-
break
71-
}
72-
}
73-
require.True(t, got1)
74-
require.True(t, got2)
55+
func idAndConnect(t *testing.T, a, b host.Host) {
56+
a.Peerstore().AddAddrs(b.ID(), b.Addrs(), peerstore.PermanentAddrTTL)
57+
a.Peerstore().AddProtocols(b.ID(), DialProtocol)
58+
59+
err := a.Connect(context.Background(), peer.AddrInfo{ID: b.ID()})
60+
require.NoError(t, err)
61+
}
62+
63+
// waitForPeer waits for a to process all peer events
64+
func waitForPeer(t *testing.T, a *AutoNAT) {
65+
t.Helper()
66+
require.Eventually(t, func() bool {
67+
a.mx.Lock()
68+
defer a.mx.Unlock()
69+
return len(a.peers) > 0
70+
}, 5*time.Second, 100*time.Millisecond)
71+
}
72+
73+
// identify provides server address and protocol to client
74+
func identify(t *testing.T, cli *AutoNAT, srv *AutoNAT) {
75+
idAndConnect(t, cli.host, srv.host)
76+
waitForPeer(t, cli)
7577
}
7678

7779
func TestAutoNATPrivateAddr(t *testing.T) {
@@ -82,19 +84,24 @@ func TestAutoNATPrivateAddr(t *testing.T) {
8284
}
8385

8486
func TestClientRequest(t *testing.T) {
85-
an := newAutoNAT(t, nil)
87+
an := newAutoNAT(t, nil, allowAll)
8688

8789
addrs := an.host.Addrs()
8890

91+
var gotReq atomic.Bool
8992
p := bhost.NewBlankHost(swarmt.GenSwarm(t))
9093
p.SetStreamHandler(DialProtocol, func(s network.Stream) {
94+
gotReq.Store(true)
9195
r := pbio.NewDelimitedReader(s, maxMsgSize)
9296
var msg pbv2.Message
93-
err := r.ReadMsg(&msg)
94-
if err != nil {
97+
if err := r.ReadMsg(&msg); err != nil {
9598
t.Error(err)
99+
return
100+
}
101+
if msg.GetDialRequest() == nil {
102+
t.Errorf("expected message to be of type DialRequest, got %T", msg.Msg)
103+
return
96104
}
97-
require.NotNil(t, msg.GetDialRequest())
98105
addrsb := make([][]byte, len(addrs))
99106
for i := 0; i < len(addrs); i++ {
100107
addrsb[i] = addrs[i].Bytes()
@@ -105,20 +112,23 @@ func TestClientRequest(t *testing.T) {
105112
s.Reset()
106113
})
107114

108-
an.host.Peerstore().AddAddrs(p.ID(), p.Addrs(), peerstore.TempAddrTTL)
109-
an.host.Peerstore().AddProtocols(p.ID(), DialProtocol)
115+
idAndConnect(t, an.host, p)
116+
waitForPeer(t, an)
117+
110118
res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
111119
require.Nil(t, res)
112120
require.NotNil(t, err)
121+
require.True(t, gotReq.Load())
113122
}
114123

115124
func TestClientServerError(t *testing.T) {
116125
an := newAutoNAT(t, nil, allowAll)
117126
addrs := an.host.Addrs()
118127

119128
p := bhost.NewBlankHost(swarmt.GenSwarm(t))
120-
an.host.Peerstore().AddAddrs(p.ID(), p.Addrs(), peerstore.PermanentAddrTTL)
121-
an.host.Peerstore().AddProtocols(p.ID(), DialProtocol)
129+
idAndConnect(t, an.host, p)
130+
waitForPeer(t, an)
131+
122132
done := make(chan bool)
123133
tests := []struct {
124134
handler func(network.Stream)
@@ -163,8 +173,9 @@ func TestClientDataRequest(t *testing.T) {
163173
addrs := an.host.Addrs()
164174

165175
p := bhost.NewBlankHost(swarmt.GenSwarm(t))
166-
an.host.Peerstore().AddAddrs(p.ID(), p.Addrs(), peerstore.PermanentAddrTTL)
167-
an.host.Peerstore().AddProtocols(p.ID(), DialProtocol)
176+
idAndConnect(t, an.host, p)
177+
waitForPeer(t, an)
178+
168179
done := make(chan bool)
169180
tests := []struct {
170181
handler func(network.Stream)
@@ -234,9 +245,8 @@ func TestClientDialAttempts(t *testing.T) {
234245
addrs := an.host.Addrs()
235246

236247
p := bhost.NewBlankHost(swarmt.GenSwarm(t))
237-
an.host.Peerstore().AddAddrs(p.ID(), p.Addrs(), peerstore.PermanentAddrTTL)
238-
an.host.Peerstore().AddProtocols(p.ID(), DialProtocol)
239-
an.cli.Register()
248+
idAndConnect(t, an.host, p)
249+
waitForPeer(t, an)
240250

241251
tests := []struct {
242252
handler func(network.Stream)
@@ -419,3 +429,41 @@ func TestClientDialAttempts(t *testing.T) {
419429
})
420430
}
421431
}
432+
433+
func TestEventSubscription(t *testing.T) {
434+
an := newAutoNAT(t, nil)
435+
defer an.host.Close()
436+
437+
b := bhost.NewBlankHost(swarmt.GenSwarm(t))
438+
defer b.Close()
439+
c := bhost.NewBlankHost(swarmt.GenSwarm(t))
440+
defer c.Close()
441+
442+
idAndConnect(t, an.host, b)
443+
require.Eventually(t, func() bool {
444+
an.mx.Lock()
445+
defer an.mx.Unlock()
446+
return len(an.peers) == 1
447+
}, 5*time.Second, 100*time.Millisecond)
448+
449+
idAndConnect(t, an.host, c)
450+
require.Eventually(t, func() bool {
451+
an.mx.Lock()
452+
defer an.mx.Unlock()
453+
return len(an.peers) == 2
454+
}, 5*time.Second, 100*time.Millisecond)
455+
456+
an.host.Network().ClosePeer(b.ID())
457+
require.Eventually(t, func() bool {
458+
an.mx.Lock()
459+
defer an.mx.Unlock()
460+
return len(an.peers) == 1
461+
}, 5*time.Second, 100*time.Millisecond)
462+
463+
an.host.Network().ClosePeer(c.ID())
464+
require.Eventually(t, func() bool {
465+
an.mx.Lock()
466+
defer an.mx.Unlock()
467+
return len(an.peers) == 0
468+
}, 5*time.Second, 100*time.Millisecond)
469+
}

0 commit comments

Comments
 (0)