Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (c *Conn) readInitialHandshake() error {
pos += 2

// The upper 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
c.capability |= uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2

// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
Expand Down Expand Up @@ -209,10 +209,8 @@ func (c *Conn) writeAuthHandshake() error {

// Set default client capabilities that reflect the abilities of this library
capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION |
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH
// Adjust client capability flags based on server support
capability |= c.capability & mysql.CLIENT_LONG_FLAG
capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH |
mysql.CLIENT_LONG_FLAG | mysql.CLIENT_QUERY_ATTRIBUTES | mysql.CLIENT_DEPRECATE_EOF
// Adjust client capability flags on specific client requests
// Only flags that would make any sense setting and aren't handled elsewhere
// in the library are supported here
Expand Down Expand Up @@ -275,6 +273,7 @@ func (c *Conn) writeAuthHandshake() error {
data := make([]byte, length+4)

// capability [32 bit]
c.capability &= capability
data[4] = byte(capability)
data[5] = byte(capability >> 8)
data[6] = byte(capability >> 16)
Expand Down
13 changes: 13 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ func (s *clientTestSuite) TestConn_Compress() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_NoDeprecateEOF() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
conn.UnsetCapability(mysql.CLIENT_DEPRECATE_EOF)
return nil
})
require.NoError(s.T(), err)

_, err = conn.Execute("SELECT VERSION()")
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCapability() {
caps := []uint32{
mysql.CLIENT_LONG_PASSWORD,
Expand All @@ -125,6 +137,7 @@ func (s *clientTestSuite) TestConn_SetCapability() {
mysql.CLIENT_PLUGIN_AUTH,
mysql.CLIENT_CONNECT_ATTRS,
mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
mysql.CLIENT_DEPRECATE_EOF,
}

for _, capI := range caps {
Expand Down
4 changes: 2 additions & 2 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (c *Conn) UnsetCapability(cap uint32) {

// HasCapability returns true if the connection has the specific capability
func (c *Conn) HasCapability(cap uint32) bool {
return c.ccaps&cap > 0
return c.ccaps&cap != 0
}

// UseSSL: use default SSL
Expand Down Expand Up @@ -466,7 +466,7 @@ func (c *Conn) FieldList(table string, wildcard string) ([]*mysql.Field, error)
}

// EOF Packet
if c.isEOFPacket(data) {
if data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why we change it to len(data) <= 0xffffff?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add comment later, bit unavailable this week

https://dev.mysql.com/worklog/task/?id=7766

In case of huge data packet with length greater than 16777216L client will treat it as a data packet and process accordingly.

return fs, nil
}

Expand Down
87 changes: 45 additions & 42 deletions client/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,7 @@ import (
"github.com/go-mysql-org/go-mysql/utils"
)

func (c *Conn) readUntilEOF() (err error) {
var data []byte

for {
data, err = c.ReadPacket()
if err != nil {
return err
}

// EOF Packet
if c.isEOFPacket(data) {
return err
}
}
}

// this should only be called when CLIENT_DEPRECATE_EOF not enabled
func (c *Conn) isEOFPacket(data []byte) bool {
return data[0] == mysql.EOF_HEADER && len(data) <= 5
}
Expand Down Expand Up @@ -336,33 +321,16 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *mysql.Re
}

func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
i := 0
var data []byte

for {
for i := range result.Fields {
rawPkgLen := len(result.RawPkg)
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
if err != nil {
return err
}
data = result.RawPkg[rawPkgLen:]

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}

if i != len(result.Fields) {
err = mysql.ErrMalformPacket
}

return err
}

if result.Fields[i] == nil {
result.Fields[i] = &mysql.Field{}
}
Expand All @@ -372,8 +340,30 @@ func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
}

result.FieldNames[utils.ByteSliceToString(result.Fields[i].Name)] = i
}

i++
if c.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
// EOF Packet
rawPkgLen := len(result.RawPkg)
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
if err != nil {
return err
}
data = result.RawPkg[rawPkgLen:]

if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}
return nil
} else {
return mysql.ErrMalformPacket
}
} else {
return nil
}
}

Expand All @@ -388,15 +378,21 @@ func (c *Conn) readResultRows(result *mysql.Result, isBinary bool) (err error) {
}
data = result.RawPkg[rawPkgLen:]

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
if data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff {
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
// Treat like OK
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
result.AffectedRows = affectedRows
result.InsertId = insertId
c.status = result.Status
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}

break
}

Expand Down Expand Up @@ -435,9 +431,16 @@ func (c *Conn) readResultRowsStreaming(result *mysql.Result, isBinary bool, perR
return err
}

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
if data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff {
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
// Treat like OK
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
result.AffectedRows = affectedRows
result.InsertId = insertId
c.status = result.Status
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
Expand Down
27 changes: 23 additions & 4 deletions client/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,33 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
}

if s.params > 0 {
if err := s.conn.readUntilEOF(); err != nil {
return nil, errors.Trace(err)
for range s.params {
if _, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
}
}
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
if packet, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
} else if !c.isEOFPacket(packet) {
return nil, mysql.ErrMalformPacket
}
}
}

if s.columns > 0 {
if err := s.conn.readUntilEOF(); err != nil {
return nil, errors.Trace(err)
// TODO process when CLIENT_CACHE_METADATA enabled
for range s.columns {
if _, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
}
}
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
if packet, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
} else if !c.isEOFPacket(packet) {
return nil, mysql.ErrMalformPacket
}
}
}

Expand Down
Loading