Skip to content

Commit bc7b1ce

Browse files
author
Florian Hartwig
committedJul 19, 2019
Make it possible to time out TCP reads
1 parent 316d3ca commit bc7b1ce

File tree

5 files changed

+36
-13
lines changed

5 files changed

+36
-13
lines changed
 

‎conn.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package mysqlproto
33
import (
44
"errors"
55
"io"
6+
"net"
7+
"time"
68
)
79

810
type Conn struct {
@@ -12,10 +14,11 @@ type Conn struct {
1214

1315
var ErrNoStream = errors.New("mysqlproto: stream is not set")
1416

15-
func ConnectPlainHandshake(rw io.ReadWriteCloser, capabilityFlags uint32,
17+
func ConnectPlainHandshake(rw net.Conn, capabilityFlags uint32,
1618
username, password, database string,
17-
connectAttrs map[string]string) (Conn, error) {
18-
stream := NewStream(rw)
19+
connectAttrs map[string]string,
20+
readTimeout time.Duration) (Conn, error) {
21+
stream := NewStream(rw, readTimeout)
1922
handshakeV10, err := ReadHandshakeV10(stream)
2023
if err != nil {
2124
return Conn{}, err

‎conn_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mysqlproto
33
import (
44
"errors"
55
"testing"
6+
"time"
67

78
"github.com/stretchr/testify/assert"
89
)
@@ -55,7 +56,7 @@ func TestConnCloseServerReplyERRPacket(t *testing.T) {
5556
0x73, 0x65, 0x64,
5657
}
5758
buf := newBuffer(data)
58-
conn := Conn{Stream: NewStream(buf), CapabilityFlags: CLIENT_PROTOCOL_41}
59+
conn := Conn{Stream: NewStream(buf, time.Duration(0)), CapabilityFlags: CLIENT_PROTOCOL_41}
5960
err := conn.Close()
6061
assert.NotNil(t, err)
6162
assert.Equal(t, err.Error(), "mysqlproto: Error: 1096 SQLSTATE: HY000 Message: No tables used")
@@ -69,7 +70,7 @@ func TestConnCloseServerReplyInvalidPacket(t *testing.T) {
6970
0x48, 0x59, 0x30, 0x30,
7071
}
7172
buf := newBuffer(data)
72-
conn := Conn{Stream: NewStream(buf)}
73+
conn := Conn{Stream: NewStream(buf, time.Duration(0))}
7374
err := conn.Close()
7475
assert.NotNil(t, err)
7576
assert.Equal(t, err.Error(), "mysqlproto: invalid ERR_PACKET payload: dd48042348593030")

‎handshake_v10_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mysqlproto
22

33
import (
44
"testing"
5+
"time"
56

67
"github.com/stretchr/testify/assert"
78
)
@@ -17,7 +18,7 @@ func TestNewHandshakeV10FullPacket(t *testing.T) {
1718
0x76, 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
1819
}
1920
stream := newBuffer(data)
20-
packet, err := ReadHandshakeV10(NewStream(stream))
21+
packet, err := ReadHandshakeV10(NewStream(stream, time.Duration(0)))
2122
assert.Nil(t, err)
2223
assert.Equal(t, packet.ProtocolVersion, byte(0x0a))
2324
assert.Equal(t, packet.ServerVersion, "5.6.25")
@@ -35,7 +36,7 @@ func TestNewHandshakeV10ShortPacket(t *testing.T) {
3536
0x00, 0x9e, 0x2e, 0x00, 0x00, 0x4f, 0x61, 0x7b, 0x65, 0x68, 0x5c,
3637
0x73, 0x4e, 0x00, 0xff, 0xf7,
3738
})
38-
packet, err := ReadHandshakeV10(NewStream(buf))
39+
packet, err := ReadHandshakeV10(NewStream(buf, time.Duration(0)))
3940
assert.Nil(t, err)
4041
assert.Equal(t, packet.ProtocolVersion, byte(0x0a))
4142
assert.Equal(t, packet.ServerVersion, "5.6.25")

‎stream.go

+21-4
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,23 @@ package mysqlproto
22

33
import (
44
"bytes"
5-
"io"
5+
"net"
6+
"time"
67
)
78

89
const PACKET_BUFFER_SIZE = 1500 // default MTU
910

1011
type Stream struct {
11-
stream io.ReadWriteCloser
12+
stream net.Conn
1213
buffer []byte
1314
read int
1415
left int
1516
syscalls int
17+
ReadTimeout time.Duration
1618
}
1719

18-
func NewStream(stream io.ReadWriteCloser) *Stream {
19-
return &Stream{stream, nil, 0, 0, 0}
20+
func NewStream(stream net.Conn, readTimeout time.Duration) *Stream {
21+
return &Stream{stream, nil, 0, 0, 0, readTimeout}
2022
}
2123

2224
func (s *Stream) Write(data []byte) (int, error) {
@@ -86,6 +88,12 @@ func (s *Stream) ResetStats() {
8688
func (s *Stream) readAtLeast(buf []byte, min int) (n int, err error) {
8789
for n < min && err == nil {
8890
var nn int
91+
if s.ReadTimeout > 0 {
92+
if err = s.stream.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil {
93+
return
94+
}
95+
}
96+
8997
nn, err = s.stream.Read(buf[n:])
9098
s.syscalls += 1
9199
n += nn
@@ -117,3 +125,12 @@ func (b *buffer) Write(data []byte) (int, error) {
117125

118126
return b.writeFn(data)
119127
}
128+
func (b *buffer) RemoteAddr() net.Addr { return MockAddr{} }
129+
func (b *buffer) LocalAddr() net.Addr { return MockAddr{} }
130+
func (b *buffer) SetDeadline(t time.Time) error { return nil}
131+
func (b *buffer) SetReadDeadline(t time.Time) error { return nil}
132+
func (b *buffer) SetWriteDeadline(t time.Time) error { return nil}
133+
134+
type MockAddr struct {}
135+
func (m MockAddr) Network() string { return "" }
136+
func (m MockAddr) String() string { return "" }

‎stream_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mysqlproto
22

33
import (
44
"testing"
5+
"time"
56

67
"github.com/stretchr/testify/assert"
78
)
@@ -13,7 +14,7 @@ func TestNextPacket7(t *testing.T) {
1314
0x01, 0x02, 0x03,
1415
})
1516

16-
stream := NewStream(buf)
17+
stream := NewStream(buf, time.Duration(0))
1718
packet, err := stream.NextPacket()
1819
assert.Nil(t, err)
1920
assert.Equal(t, packet.SequenceID, byte(0x02))
@@ -39,7 +40,7 @@ func TestNextPacket256(t *testing.T) {
3940
0x00, 0x00, 0x00, 0x02, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x02,
4041
})
4142

42-
stream := NewStream(buf)
43+
stream := NewStream(buf, time.Duration(0))
4344
packet, err := stream.NextPacket()
4445
assert.Nil(t, err)
4546
assert.Equal(t, packet.SequenceID, byte(0x02))

0 commit comments

Comments
 (0)
Please sign in to comment.