Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it possible to time out TCP reads #2

Merged
merged 2 commits into from
Jul 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package mysqlproto
import (
"errors"
"io"
"net"
"time"
)

type Conn struct {
Expand All @@ -12,10 +14,11 @@ type Conn struct {

var ErrNoStream = errors.New("mysqlproto: stream is not set")

func ConnectPlainHandshake(rw io.ReadWriteCloser, capabilityFlags uint32,
func ConnectPlainHandshake(rw net.Conn, capabilityFlags uint32,
username, password, database string,
connectAttrs map[string]string) (Conn, error) {
stream := NewStream(rw)
connectAttrs map[string]string,
readTimeout time.Duration) (Conn, error) {
stream := NewStream(rw, readTimeout)
handshakeV10, err := ReadHandshakeV10(stream)
if err != nil {
return Conn{}, err
Expand Down
5 changes: 3 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mysqlproto
import (
"errors"
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -55,7 +56,7 @@ func TestConnCloseServerReplyERRPacket(t *testing.T) {
0x73, 0x65, 0x64,
}
buf := newBuffer(data)
conn := Conn{Stream: NewStream(buf), CapabilityFlags: CLIENT_PROTOCOL_41}
conn := Conn{Stream: NewStream(buf, time.Duration(0)), CapabilityFlags: CLIENT_PROTOCOL_41}
err := conn.Close()
assert.NotNil(t, err)
assert.Equal(t, err.Error(), "mysqlproto: Error: 1096 SQLSTATE: HY000 Message: No tables used")
Expand All @@ -69,7 +70,7 @@ func TestConnCloseServerReplyInvalidPacket(t *testing.T) {
0x48, 0x59, 0x30, 0x30,
}
buf := newBuffer(data)
conn := Conn{Stream: NewStream(buf)}
conn := Conn{Stream: NewStream(buf, time.Duration(0))}
err := conn.Close()
assert.NotNil(t, err)
assert.Equal(t, err.Error(), "mysqlproto: invalid ERR_PACKET payload: dd48042348593030")
Expand Down
5 changes: 3 additions & 2 deletions handshake_v10_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mysqlproto

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand All @@ -17,7 +18,7 @@ func TestNewHandshakeV10FullPacket(t *testing.T) {
0x76, 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
}
stream := newBuffer(data)
packet, err := ReadHandshakeV10(NewStream(stream))
packet, err := ReadHandshakeV10(NewStream(stream, time.Duration(0)))
assert.Nil(t, err)
assert.Equal(t, packet.ProtocolVersion, byte(0x0a))
assert.Equal(t, packet.ServerVersion, "5.6.25")
Expand All @@ -35,7 +36,7 @@ func TestNewHandshakeV10ShortPacket(t *testing.T) {
0x00, 0x9e, 0x2e, 0x00, 0x00, 0x4f, 0x61, 0x7b, 0x65, 0x68, 0x5c,
0x73, 0x4e, 0x00, 0xff, 0xf7,
})
packet, err := ReadHandshakeV10(NewStream(buf))
packet, err := ReadHandshakeV10(NewStream(buf, time.Duration(0)))
assert.Nil(t, err)
assert.Equal(t, packet.ProtocolVersion, byte(0x0a))
assert.Equal(t, packet.ServerVersion, "5.6.25")
Expand Down
25 changes: 21 additions & 4 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@ package mysqlproto

import (
"bytes"
"io"
"net"
"time"
)

const PACKET_BUFFER_SIZE = 1500 // default MTU

type Stream struct {
stream io.ReadWriteCloser
stream net.Conn
buffer []byte
read int
left int
syscalls int
readTimeout time.Duration
}

func NewStream(stream io.ReadWriteCloser) *Stream {
return &Stream{stream, nil, 0, 0, 0}
func NewStream(stream net.Conn, readTimeout time.Duration) *Stream {
return &Stream{stream, nil, 0, 0, 0, readTimeout}
}

func (s *Stream) Write(data []byte) (int, error) {
Expand Down Expand Up @@ -86,6 +88,12 @@ func (s *Stream) ResetStats() {
func (s *Stream) readAtLeast(buf []byte, min int) (n int, err error) {
for n < min && err == nil {
var nn int
if s.readTimeout > 0 {
if err = s.stream.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
return
}
}

nn, err = s.stream.Read(buf[n:])
s.syscalls += 1
n += nn
Expand Down Expand Up @@ -117,3 +125,12 @@ func (b *buffer) Write(data []byte) (int, error) {

return b.writeFn(data)
}
func (b *buffer) RemoteAddr() net.Addr { return MockAddr{} }
func (b *buffer) LocalAddr() net.Addr { return MockAddr{} }
func (b *buffer) SetDeadline(t time.Time) error { return nil}
func (b *buffer) SetReadDeadline(t time.Time) error { return nil}
func (b *buffer) SetWriteDeadline(t time.Time) error { return nil}

type MockAddr struct {}
func (m MockAddr) Network() string { return "" }
func (m MockAddr) String() string { return "" }
5 changes: 3 additions & 2 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mysqlproto

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand All @@ -13,7 +14,7 @@ func TestNextPacket7(t *testing.T) {
0x01, 0x02, 0x03,
})

stream := NewStream(buf)
stream := NewStream(buf, time.Duration(0))
packet, err := stream.NextPacket()
assert.Nil(t, err)
assert.Equal(t, packet.SequenceID, byte(0x02))
Expand All @@ -39,7 +40,7 @@ func TestNextPacket256(t *testing.T) {
0x00, 0x00, 0x00, 0x02, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x02,
})

stream := NewStream(buf)
stream := NewStream(buf, time.Duration(0))
packet, err := stream.NextPacket()
assert.Nil(t, err)
assert.Equal(t, packet.SequenceID, byte(0x02))
Expand Down