Skip to content

Commit d8f93fc

Browse files
committed
swarm: fix DialPeer behaviour for transient connections
This fixes a bug where the first call to `swarm.DialPeer` succeeds and returns a transient connection with no error, while the second call to `DialPeer` returns `(nil, network.ErrTransientConn)`. When dialing, we now rely only on `network.WithForceDirectDial` to force a direct connection. When opening a new stream, we use a transient connection only if `network.WithUseTransient` is set.
1 parent 1153b1b commit d8f93fc

File tree

4 files changed

+88
-30
lines changed

4 files changed

+88
-30
lines changed

p2p/net/swarm/dial_worker.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ loop:
159159
// Enqueue the peer's addresses relevant to this request in dq and
160160
// track dials to the addresses relevant to this request.
161161

162-
c, err := w.s.bestAcceptableConnToPeer(req.ctx, w.peer)
163-
if c != nil || err != nil {
164-
req.resch <- dialResponse{conn: c, err: err}
162+
c := w.s.bestAcceptableConnToPeer(req.ctx, w.peer)
163+
if c != nil {
164+
req.resch <- dialResponse{conn: c}
165165
continue loop
166166
}
167167

@@ -373,7 +373,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) {
373373
// all addrs have erred, dispatch dial error
374374
// but first do one last check in case an acceptable connection has landed from
375375
// a simultaneous dial that started later and added new acceptable addrs
376-
c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer)
376+
c := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer)
377377
if c != nil {
378378
pr.req.resch <- dialResponse{conn: c}
379379
} else {

p2p/net/swarm/swarm.go

+9-20
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,8 @@ func (s *Swarm) StreamHandler() network.StreamHandler {
428428
}
429429

430430
// NewStream creates a new stream on any available connection to peer, dialing
431-
// if necessary.
431+
// if necessary. Use network.WithUseTransient to open a stream over a transient (relayed)
432+
// connection.
432433
func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error) {
433434
log.Debugf("[%s] opening stream to peer [%s]", s.local, p)
434435

@@ -447,10 +448,7 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error
447448
dials := 0
448449
for {
449450
// will prefer direct connections over relayed connections for opening streams
450-
c, err := s.bestAcceptableConnToPeer(ctx, p)
451-
if err != nil {
452-
return nil, err
453-
}
451+
c := s.bestAcceptableConnToPeer(ctx, p)
454452

455453
if c == nil {
456454
if nodial, _ := network.GetNoDial(ctx); nodial {
@@ -548,26 +546,17 @@ func (s *Swarm) bestConnToPeer(p peer.ID) *Conn {
548546
return best
549547
}
550548

551-
// - Returns the best "acceptable" connection, if available.
552-
// - Returns nothing if no such connection exists, but if we should try dialing anyways.
553-
// - Returns an error if no such connection exists, but we should not try dialing.
554-
func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) (*Conn, error) {
549+
// bestAcceptableConnToPeer returns the best acceptable connection to the peer for the
550+
// given ctx. If network.WithForceDirectDial is used, it only returns a direct connection,
551+
// ignoring any transient (relayed) connections to the peer.
552+
func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) *Conn {
555553
conn := s.bestConnToPeer(p)
556-
if conn == nil {
557-
return nil, nil
558-
}
559554

560555
forceDirect, _ := network.GetForceDirectDial(ctx)
561556
if forceDirect && !isDirectConn(conn) {
562-
return nil, nil
563-
}
564-
565-
useTransient, _ := network.GetUseTransient(ctx)
566-
if useTransient || !conn.Stat().Transient {
567-
return conn, nil
557+
return nil
568558
}
569-
570-
return nil, network.ErrTransientConn
559+
return conn
571560
}
572561

573562
func isDirectConn(c *Conn) bool {

p2p/net/swarm/swarm_dial.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ func (db *DialBackoff) cleanup() {
216216
}
217217
}
218218

219-
// DialPeer connects to a peer.
219+
// DialPeer connects to a peer. Use network.WithForceDirectDial to force a
220+
// direct connection.
220221
//
221222
// The idea is that the client of Swarm does not need to know what network
222223
// the connection will happen over. Swarm can use whichever it chooses.
@@ -246,11 +247,10 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
246247
return nil, ErrDialToSelf
247248
}
248249

249-
// check if we already have an open (usable) connection first, or can't have a usable
250-
// connection.
251-
conn, err := s.bestAcceptableConnToPeer(ctx, p)
252-
if conn != nil || err != nil {
253-
return conn, err
250+
// check if we already have an open (usable) connection.
251+
conn := s.bestAcceptableConnToPeer(ctx, p)
252+
if conn != nil {
253+
return conn, nil
254254
}
255255

256256
if s.gater != nil && !s.gater.InterceptPeerDial(p) {

p2p/test/swarm/swarm_test.go

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package swarm_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/libp2p/go-libp2p"
8+
"github.com/libp2p/go-libp2p/core/network"
9+
"github.com/libp2p/go-libp2p/core/peer"
10+
"github.com/libp2p/go-libp2p/core/peerstore"
11+
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
12+
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
13+
ma "github.com/multiformats/go-multiaddr"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func TestDialPeerTransientConnection(t *testing.T) {
18+
h1, err := libp2p.New(
19+
libp2p.NoListenAddrs,
20+
libp2p.EnableRelay(),
21+
)
22+
require.NoError(t, err)
23+
24+
h2, err := libp2p.New(
25+
libp2p.NoListenAddrs,
26+
libp2p.EnableRelay(),
27+
)
28+
require.NoError(t, err)
29+
30+
relay1, err := libp2p.New()
31+
require.NoError(t, err)
32+
33+
_, err = relay.New(relay1)
34+
require.NoError(t, err)
35+
36+
relay1info := peer.AddrInfo{
37+
ID: relay1.ID(),
38+
Addrs: relay1.Addrs(),
39+
}
40+
err = h1.Connect(context.Background(), relay1info)
41+
require.NoError(t, err)
42+
43+
err = h2.Connect(context.Background(), relay1info)
44+
require.NoError(t, err)
45+
46+
_, err = client.Reserve(context.Background(), h2, relay1info)
47+
require.NoError(t, err)
48+
49+
relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String())
50+
51+
h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL)
52+
53+
// swarm.DialPeer should connect over transient connections
54+
conn1, err := h1.Network().DialPeer(context.Background(), h2.ID())
55+
require.NoError(t, err)
56+
require.NotNil(t, conn1)
57+
58+
conn2, err := h1.Network().DialPeer(context.Background(), h2.ID())
59+
require.NoError(t, err)
60+
require.NotNil(t, conn2)
61+
62+
require.Equal(t, conn1, conn2)
63+
64+
// swarm.DialPeer should fail if forceDirect is used
65+
ctx := network.WithForceDirectDial(context.Background(), "test")
66+
conn, err := h1.Network().DialPeer(ctx, h2.ID())
67+
require.Error(t, err)
68+
require.Nil(t, conn)
69+
}

0 commit comments

Comments
 (0)