Skip to content

Commit 2d88d54

Browse files
committed
swarm: wait for transient connections to upgrade for NewStream
1 parent a29a92e commit 2d88d54

File tree

3 files changed

+168
-20
lines changed

3 files changed

+168
-20
lines changed

p2p/net/swarm/swarm.go

+93-18
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/libp2p/go-libp2p/core/peer"
1818
"github.com/libp2p/go-libp2p/core/peerstore"
1919
"github.com/libp2p/go-libp2p/core/transport"
20+
"golang.org/x/exp/slices"
2021

2122
logging "github.com/ipfs/go-log/v2"
2223
ma "github.com/multiformats/go-multiaddr"
@@ -172,6 +173,11 @@ type Swarm struct {
172173
m map[network.Notifiee]struct{}
173174
}
174175

176+
directConnNotifs struct {
177+
sync.Mutex
178+
m map[peer.ID][]chan struct{}
179+
}
180+
175181
transports struct {
176182
sync.RWMutex
177183
m map[int]transport.Transport
@@ -231,6 +237,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts
231237
s.listeners.m = make(map[transport.Listener]struct{})
232238
s.transports.m = make(map[int]transport.Transport)
233239
s.notifs.m = make(map[network.Notifiee]struct{})
240+
s.directConnNotifs.m = make(map[peer.ID][]chan struct{})
234241

235242
for _, opt := range opts {
236243
if err := opt(s); err != nil {
@@ -390,6 +397,19 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
390397
c.notifyLk.Lock()
391398
s.conns.Unlock()
392399

400+
// Notify goroutines waiting for a direct connection
401+
402+
// Go routines interested in waiting for direct connection first acquire this lock and then
403+
// acquire conns.RLock. Do not acquire this lock before conns.Unlock to prevent deadlock.
404+
s.directConnNotifs.Lock()
405+
if !c.Stat().Transient {
406+
for _, ch := range s.directConnNotifs.m[p] {
407+
close(ch)
408+
}
409+
delete(s.directConnNotifs.m, p)
410+
}
411+
s.directConnNotifs.Unlock()
412+
393413
// Emit event after releasing `s.conns` lock so that a consumer can still
394414
// use swarm methods that need the `s.conns` lock.
395415
if isFirstConnection {
@@ -435,46 +455,101 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error
435455

436456
// Algorithm:
437457
// 1. Find the best connection, otherwise, dial.
438-
// 2. Try opening a stream.
439-
// 3. If the underlying connection is, in fact, closed, close the outer
458+
// 2. If the best connection is transient, wait for a direct conn via conn
459+
// reversal or hole punching.
460+
// 3. Try opening a stream.
461+
// 4. If the underlying connection is, in fact, closed, close the outer
440462
// connection and try again. We do this in case we have a closed
441463
// connection but don't notice it until we actually try to open a
442464
// stream.
443465
//
444-
// Note: We only dial once.
445-
//
446466
// TODO: Try all connections even if we get an error opening a stream on
447467
// a non-closed connection.
448-
dials := 0
449-
for {
450-
// will prefer direct connections over relayed connections for opening streams
451-
c := s.bestAcceptableConnToPeer(ctx, p)
452-
468+
dialed := false
469+
for i := 0; i < 1; i++ {
470+
c := s.bestConnToPeer(p)
453471
if c == nil {
454-
if nodial, _ := network.GetNoDial(ctx); nodial {
472+
if nodial, _ := network.GetNoDial(ctx); !nodial {
473+
if dialed {
474+
return nil, errors.New("max dial attempts exceeded")
475+
}
476+
dialed = true
477+
var err error
478+
c, err = s.dialPeer(ctx, p)
479+
if err != nil {
480+
return nil, err
481+
}
482+
} else {
455483
return nil, network.ErrNoConn
456484
}
485+
}
457486

458-
if dials >= DialAttempts {
459-
return nil, errors.New("max dial attempts exceeded")
460-
}
461-
dials++
462-
487+
useTransient, _ := network.GetUseTransient(ctx)
488+
if !useTransient && c.Stat().Transient {
463489
var err error
464-
c, err = s.dialPeer(ctx, p)
490+
c, err = s.waitForDirectConn(ctx, p)
465491
if err != nil {
466492
return nil, err
467493
}
468494
}
469495

470-
s, err := c.NewStream(ctx)
496+
str, err := c.NewStream(ctx)
471497
if err != nil {
472498
if c.conn.IsClosed() {
473499
continue
474500
}
475501
return nil, err
476502
}
477-
return s, nil
503+
return str, nil
504+
}
505+
return nil, network.ErrNoConn
506+
}
507+
508+
// waitForDirectConn waits for a direct connection established through hole punching or connection reversal.
509+
func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error) {
510+
s.directConnNotifs.Lock()
511+
c := s.bestConnToPeer(p)
512+
if c == nil {
513+
s.directConnNotifs.Unlock()
514+
return nil, network.ErrNoConn
515+
} else if !c.Stat().Transient {
516+
s.directConnNotifs.Unlock()
517+
return c, nil
518+
}
519+
520+
// Wait for transient connection to upgrade to a direct connection either by
521+
// connection reversal or hole punching.
522+
ch := make(chan struct{})
523+
s.directConnNotifs.m[p] = append(s.directConnNotifs.m[p], ch)
524+
s.directConnNotifs.Unlock()
525+
526+
// Wait for notification.
527+
// There's no point waiting for more than a minute here.
528+
ctx, cancel := context.WithTimeout(ctx, time.Minute)
529+
defer cancel()
530+
select {
531+
case <-ctx.Done():
532+
// Remove ourselves from the notification list
533+
s.directConnNotifs.Lock()
534+
s.directConnNotifs.m[p] = slices.DeleteFunc(
535+
s.directConnNotifs.m[p],
536+
func(c chan struct{}) bool { return c == ch },
537+
)
538+
if len(s.directConnNotifs.m[p]) == 0 {
539+
delete(s.directConnNotifs.m, p)
540+
}
541+
s.directConnNotifs.Unlock()
542+
return nil, ctx.Err()
543+
case <-ch:
544+
// We do not need to remove ourselves from the list here as the notifier
545+
// clears the map
546+
c := s.bestConnToPeer(p)
547+
if c == nil {
548+
return nil, network.ErrNoConn
549+
} else if c.Stat().Transient {
550+
return nil, network.ErrTransientConn
551+
}
552+
return c, nil
478553
}
479554
}
480555

p2p/test/basichost/basic_host_test.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"testing"
7+
"time"
78

89
"github.com/libp2p/go-libp2p"
910
"github.com/libp2p/go-libp2p/core/network"
@@ -62,10 +63,12 @@ func TestNoStreamOverTransientConnection(t *testing.T) {
6263
err = h1.Connect(context.Background(), h2Info)
6364
require.NoError(t, err)
6465

65-
ctx := network.WithNoDial(context.Background(), "test")
66+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
67+
defer cancel()
68+
ctx = network.WithNoDial(ctx, "test")
6669
_, err = h1.NewStream(ctx, h2.ID(), "/testprotocol")
6770

68-
require.ErrorIs(t, err, network.ErrTransientConn)
71+
require.Error(t, err)
6972

7073
_, err = h1.NewStream(network.WithUseTransient(context.Background(), "test"), h2.ID(), "/testprotocol")
7174
require.NoError(t, err)

p2p/test/swarm/swarm_test.go

+70
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"testing"
7+
"time"
78

89
"github.com/libp2p/go-libp2p"
910
"github.com/libp2p/go-libp2p/core/network"
@@ -12,6 +13,7 @@ import (
1213
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
1314
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
1415
ma "github.com/multiformats/go-multiaddr"
16+
"github.com/stretchr/testify/assert"
1517
"github.com/stretchr/testify/require"
1618
)
1719

@@ -75,3 +77,71 @@ func TestDialPeerTransientConnection(t *testing.T) {
7577
require.Error(t, err)
7678
require.Nil(t, conn)
7779
}
80+
81+
func TestNewStreamTransientConnection(t *testing.T) {
82+
h1, err := libp2p.New(
83+
libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
84+
libp2p.EnableRelay(),
85+
)
86+
require.NoError(t, err)
87+
88+
h2, err := libp2p.New(
89+
libp2p.NoListenAddrs,
90+
libp2p.EnableRelay(),
91+
)
92+
require.NoError(t, err)
93+
94+
relay1, err := libp2p.New()
95+
require.NoError(t, err)
96+
97+
_, err = relay.New(relay1)
98+
require.NoError(t, err)
99+
100+
relay1info := peer.AddrInfo{
101+
ID: relay1.ID(),
102+
Addrs: relay1.Addrs(),
103+
}
104+
err = h1.Connect(context.Background(), relay1info)
105+
require.NoError(t, err)
106+
107+
err = h2.Connect(context.Background(), relay1info)
108+
require.NoError(t, err)
109+
110+
h2.SetStreamHandler("/testprotocol", func(s network.Stream) {
111+
fmt.Println("testprotocol")
112+
113+
// End the example
114+
s.Close()
115+
})
116+
117+
_, err = client.Reserve(context.Background(), h2, relay1info)
118+
require.NoError(t, err)
119+
120+
relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String())
121+
122+
h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL)
123+
124+
// NewStream should block transient connections till we have a direct connection
125+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
126+
defer cancel()
127+
s, err := h1.Network().NewStream(ctx, h2.ID())
128+
require.ErrorIs(t, err, context.DeadlineExceeded)
129+
require.Nil(t, s)
130+
131+
// NewStream should return a stream if a direct connection is established
132+
// while waiting
133+
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
134+
defer cancel()
135+
time.AfterFunc(time.Second, func() {
136+
// connect h2 to h1 simulating connection reversal
137+
h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL)
138+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
139+
defer cancel()
140+
ctx = network.WithForceDirectDial(ctx, "test")
141+
err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()})
142+
assert.NoError(t, err)
143+
})
144+
s, err = h1.Network().NewStream(ctx, h2.ID())
145+
require.NoError(t, err)
146+
require.NotNil(t, s)
147+
}

0 commit comments

Comments
 (0)