diff --git a/conn.go b/conn.go index 76bf87e..c2203e3 100644 --- a/conn.go +++ b/conn.go @@ -3,6 +3,8 @@ package mysqlproto import ( "errors" "io" + "net" + "time" ) type Conn struct { @@ -12,10 +14,11 @@ type Conn struct { var ErrNoStream = errors.New("mysqlproto: stream is not set") -func ConnectPlainHandshake(rw io.ReadWriteCloser, capabilityFlags uint32, +func ConnectPlainHandshake(rw net.Conn, capabilityFlags uint32, username, password, database string, - connectAttrs map[string]string) (Conn, error) { - stream := NewStream(rw) + connectAttrs map[string]string, + readTimeout time.Duration) (Conn, error) { + stream := NewStream(rw, readTimeout) handshakeV10, err := ReadHandshakeV10(stream) if err != nil { return Conn{}, err diff --git a/conn_test.go b/conn_test.go index e4955ab..199fd17 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,7 @@ package mysqlproto import ( "errors" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -55,7 +56,7 @@ func TestConnCloseServerReplyERRPacket(t *testing.T) { 0x73, 0x65, 0x64, } buf := newBuffer(data) - conn := Conn{Stream: NewStream(buf), CapabilityFlags: CLIENT_PROTOCOL_41} + conn := Conn{Stream: NewStream(buf, time.Duration(0)), CapabilityFlags: CLIENT_PROTOCOL_41} err := conn.Close() assert.NotNil(t, err) assert.Equal(t, err.Error(), "mysqlproto: Error: 1096 SQLSTATE: HY000 Message: No tables used") @@ -69,7 +70,7 @@ func TestConnCloseServerReplyInvalidPacket(t *testing.T) { 0x48, 0x59, 0x30, 0x30, } buf := newBuffer(data) - conn := Conn{Stream: NewStream(buf)} + conn := Conn{Stream: NewStream(buf, time.Duration(0))} err := conn.Close() assert.NotNil(t, err) assert.Equal(t, err.Error(), "mysqlproto: invalid ERR_PACKET payload: dd48042348593030") diff --git a/handshake_v10_test.go b/handshake_v10_test.go index c4179b1..bb6cad0 100644 --- a/handshake_v10_test.go +++ b/handshake_v10_test.go @@ -2,6 +2,7 @@ package mysqlproto import ( "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -17,7 +18,7 @@ func TestNewHandshakeV10FullPacket(t *testing.T) { 0x76, 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, } stream := newBuffer(data) - packet, err := ReadHandshakeV10(NewStream(stream)) + packet, err := ReadHandshakeV10(NewStream(stream, time.Duration(0))) assert.Nil(t, err) assert.Equal(t, packet.ProtocolVersion, byte(0x0a)) assert.Equal(t, packet.ServerVersion, "5.6.25") @@ -35,7 +36,7 @@ func TestNewHandshakeV10ShortPacket(t *testing.T) { 0x00, 0x9e, 0x2e, 0x00, 0x00, 0x4f, 0x61, 0x7b, 0x65, 0x68, 0x5c, 0x73, 0x4e, 0x00, 0xff, 0xf7, }) - packet, err := ReadHandshakeV10(NewStream(buf)) + packet, err := ReadHandshakeV10(NewStream(buf, time.Duration(0))) assert.Nil(t, err) assert.Equal(t, packet.ProtocolVersion, byte(0x0a)) assert.Equal(t, packet.ServerVersion, "5.6.25") diff --git a/stream.go b/stream.go index fc0a27d..9b72e3b 100644 --- a/stream.go +++ b/stream.go @@ -2,21 +2,23 @@ package mysqlproto import ( "bytes" - "io" + "net" + "time" ) const PACKET_BUFFER_SIZE = 1500 // default MTU type Stream struct { - stream io.ReadWriteCloser + stream net.Conn buffer []byte read int left int syscalls int + readTimeout time.Duration } -func NewStream(stream io.ReadWriteCloser) *Stream { - return &Stream{stream, nil, 0, 0, 0} +func NewStream(stream net.Conn, readTimeout time.Duration) *Stream { + return &Stream{stream, nil, 0, 0, 0, readTimeout} } func (s *Stream) Write(data []byte) (int, error) { @@ -86,6 +88,12 @@ func (s *Stream) ResetStats() { func (s *Stream) readAtLeast(buf []byte, min int) (n int, err error) { for n < min && err == nil { var nn int + if s.readTimeout > 0 { + if err = s.stream.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil { + return + } + } + nn, err = s.stream.Read(buf[n:]) s.syscalls += 1 n += nn @@ -117,3 +125,12 @@ func (b *buffer) Write(data []byte) (int, error) { return b.writeFn(data) } +func (b *buffer) RemoteAddr() net.Addr { return MockAddr{} } +func (b *buffer) LocalAddr() net.Addr { return MockAddr{} } +func (b *buffer) SetDeadline(t time.Time) error { return nil} +func (b *buffer) SetReadDeadline(t time.Time) error { return nil} +func (b *buffer) SetWriteDeadline(t time.Time) error { return nil} + +type MockAddr struct {} +func (m MockAddr) Network() string { return "" } +func (m MockAddr) String() string { return "" } diff --git a/stream_test.go b/stream_test.go index fe1a9f3..c0a65cd 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,6 +2,7 @@ package mysqlproto import ( "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -13,7 +14,7 @@ func TestNextPacket7(t *testing.T) { 0x01, 0x02, 0x03, }) - stream := NewStream(buf) + stream := NewStream(buf, time.Duration(0)) packet, err := stream.NextPacket() assert.Nil(t, err) assert.Equal(t, packet.SequenceID, byte(0x02)) @@ -39,7 +40,7 @@ func TestNextPacket256(t *testing.T) { 0x00, 0x00, 0x00, 0x02, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x02, }) - stream := NewStream(buf) + stream := NewStream(buf, time.Duration(0)) packet, err := stream.NextPacket() assert.Nil(t, err) assert.Equal(t, packet.SequenceID, byte(0x02))