@@ -15,9 +15,10 @@ import (
15
15
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
16
16
"github.com/libp2p/go-msgio/pbio"
17
17
18
+ "math/rand"
19
+
18
20
ma "github.com/multiformats/go-multiaddr"
19
21
manet "github.com/multiformats/go-multiaddr/net"
20
- "golang.org/x/exp/rand"
21
22
)
22
23
23
24
type dataRequestPolicyFunc = func (s network.Stream , dialAddr ma.Multiaddr ) bool
@@ -32,7 +33,8 @@ type server struct {
32
33
33
34
// dialDataRequestPolicy is used to determine whether dialing the address requires receiving
34
35
// dial data. It is set to amplification attack prevention by default.
35
- dialDataRequestPolicy dataRequestPolicyFunc
36
+ dialDataRequestPolicy dataRequestPolicyFunc
37
+ amplificatonAttackPreventionDialWait time.Duration
36
38
37
39
// for tests
38
40
now func () time.Time
@@ -41,10 +43,11 @@ type server struct {
41
43
42
44
func newServer (host , dialer host.Host , s * autoNATSettings ) * server {
43
45
return & server {
44
- dialerHost : dialer ,
45
- host : host ,
46
- dialDataRequestPolicy : s .dataRequestPolicy ,
47
- allowPrivateAddrs : s .allowPrivateAddrs ,
46
+ dialerHost : dialer ,
47
+ host : host ,
48
+ dialDataRequestPolicy : s .dataRequestPolicy ,
49
+ amplificatonAttackPreventionDialWait : s .amplificatonAttackPreventionDialWait ,
50
+ allowPrivateAddrs : s .allowPrivateAddrs ,
48
51
limiter : & rateLimiter {
49
52
RPM : s .serverRPM ,
50
53
PerPeerRPM : s .serverPerPeerRPM ,
@@ -81,6 +84,9 @@ func (as *server) handleDialRequest(s network.Stream) {
81
84
}
82
85
defer s .Scope ().ReleaseMemory (maxMsgSize )
83
86
87
+ deadline := as .now ().Add (streamTimeout )
88
+ ctx , cancel := context .WithDeadline (context .Background (), deadline )
89
+ defer cancel ()
84
90
s .SetDeadline (as .now ().Add (streamTimeout ))
85
91
defer s .Close ()
86
92
@@ -183,9 +189,20 @@ func (as *server) handleDialRequest(s network.Stream) {
183
189
log .Debugf ("%s refused dial data request: %s" , p , err )
184
190
return
185
191
}
192
+ // wait for a bit to prevent thundering herd style attacks on a victim
193
+ waitTime := time .Duration (rand .Intn (int (as .amplificatonAttackPreventionDialWait ) + 1 )) // the range is [0, n)
194
+ t := time .NewTimer (waitTime )
195
+ defer t .Stop ()
196
+ select {
197
+ case <- ctx .Done ():
198
+ s .Reset ()
199
+ log .Debugf ("rejecting request without dialing: %s %p " , p , ctx .Err ())
200
+ return
201
+ case <- t .C :
202
+ }
186
203
}
187
204
188
- dialStatus := as .dialBack (s .Conn ().RemotePeer (), dialAddr , nonce )
205
+ dialStatus := as .dialBack (ctx , s .Conn ().RemotePeer (), dialAddr , nonce )
189
206
msg = pb.Message {
190
207
Msg : & pb.Message_DialResponse {
191
208
DialResponse : & pb.DialResponse {
@@ -252,8 +269,8 @@ func readDialData(numBytes int, r io.Reader) error {
252
269
return nil
253
270
}
254
271
255
- func (as * server ) dialBack (p peer.ID , addr ma.Multiaddr , nonce uint64 ) pb.DialStatus {
256
- ctx , cancel := context .WithTimeout (context . Background () , dialBackDialTimeout )
272
+ func (as * server ) dialBack (ctx context. Context , p peer.ID , addr ma.Multiaddr , nonce uint64 ) pb.DialStatus {
273
+ ctx , cancel := context .WithTimeout (ctx , dialBackDialTimeout )
257
274
ctx = network .WithForceDirectDial (ctx , "autonatv2" )
258
275
as .dialerHost .Peerstore ().AddAddr (p , addr , peerstore .TempAddrTTL )
259
276
defer func () {
0 commit comments